package service

import (
	"context"
	"fmt"

	"next-terminal/server/global/gateway"
	"next-terminal/server/log"
	"next-terminal/server/model"
	"next-terminal/server/repository"
)

// GatewayService is the package-level instance used to look up, load and
// unload gateways held by gateway.GlobalGatewayManager.
var GatewayService = new(gatewayService)

type gatewayService struct{}

// GetGatewayById returns the gateway registered under accessGatewayId.
//
// A gateway already held by gateway.GlobalGatewayManager is returned as-is.
// Otherwise the id is looked up first through SshGatewayRepository (the
// ssh_gateways table) and then through GatewayRepository (the access_gateways
// table), and the record found is loaded into the manager. Any error from the
// SSH gateway lookup, not only "not found", triggers the fallback. If the
// fallback lookup also fails, its error is returned.
func (r gatewayService) GetGatewayById(accessGatewayId string) (*gateway.Gateway, error) {
	if g := gateway.GlobalGatewayManager.GetById(accessGatewayId); g != nil {
		return g, nil
	}

	// Try the ssh_gateways table first.
	sshGateway, err := repository.SshGatewayRepository.FindById(context.TODO(), accessGatewayId)
	if err == nil {
		return r.loadSshGateway(&sshGateway)
	}

	// Fall back to the access_gateways table.
	accessGateway, err := repository.GatewayRepository.FindById(context.TODO(), accessGatewayId)
	if err != nil {
		return nil, err
	}
	return r.ReLoad(&accessGateway), nil
}

// loadSshGateway (re)registers m in gateway.GlobalGatewayManager and returns
// the registered gateway.
//
// Any gateway already registered under m.ID is removed first. The connection
// fields of m are then filled in according to m.ConfigMode; note that m is
// modified in place. The result is added to the manager. If the connection
// details cannot be resolved, an error is returned and nothing is registered
// under m.ID, because the previous entry has already been removed.
func (r gatewayService) loadSshGateway(m *model.SshGateway) (*gateway.Gateway, error) {
	r.DisconnectById(m.ID)

	if err := r.resolveSshGatewayConnection(m); err != nil {
		return nil, err
	}
	return gateway.GlobalGatewayManager.AddFromSshGateway(m), nil
}

// resolveSshGatewayConnection fills m's connection fields according to
// m.ConfigMode:
//   - "credential": Username, Password, PrivateKey and Passphrase are copied
//     from the decrypted credential m.CredentialId;
//   - "asset": IP, Port, Username, Password, PrivateKey and Passphrase are
//     copied from the decrypted asset m.AssetId;
//   - "direct" or any other value: m's own fields are used unchanged.
//
// It returns an error when the referenced credential/asset id is empty or
// cannot be loaded.
func (r gatewayService) resolveSshGatewayConnection(m *model.SshGateway) error {
	switch m.ConfigMode {
	case "credential":
		if m.CredentialId == "" {
			// Nothing to take the connection details from; fail explicitly
			// rather than handing callers a nil gateway with a nil error.
			return fmt.Errorf("ssh gateway %s: config mode %q requires a credential id", m.ID, m.ConfigMode)
		}
		credential, err := CredentialService.FindByIdAndDecrypt(context.TODO(), m.CredentialId)
		if err != nil {
			return err
		}
		m.Username = credential.Username
		m.Password = credential.Password
		m.PrivateKey = credential.PrivateKey
		m.Passphrase = credential.Passphrase
	case "asset":
		if m.AssetId == "" {
			// Same reasoning as the credential case above.
			return fmt.Errorf("ssh gateway %s: config mode %q requires an asset id", m.ID, m.ConfigMode)
		}
		asset, err := AssetService.FindByIdAndDecrypt(context.TODO(), m.AssetId)
		if err != nil {
			return err
		}
		m.IP = asset.IP
		m.Port = asset.Port
		m.Username = asset.Username
		m.Password = asset.Password
		m.PrivateKey = asset.PrivateKey
		m.Passphrase = asset.Passphrase
	}
	// "direct" and unrecognised modes use the fields stored on m as-is.
	return nil
}

// LoadAll registers every stored gateway in gateway.GlobalGatewayManager,
// first all access gateways and then all SSH gateways.
//
// An error while listing either table aborts the load and is returned. A
// failure to load an individual SSH gateway is logged and skipped, so one
// misconfigured gateway does not prevent the others from loading.
func (r gatewayService) LoadAll() error {
	gateways, err := repository.GatewayRepository.FindAll(context.TODO())
	if err != nil {
		return err
	}
	for i := range gateways {
		r.ReLoad(&gateways[i])
	}

	sshGateways, err := repository.SshGatewayRepository.FindAll(context.TODO())
	if err != nil {
		return err
	}
	for i := range sshGateways {
		if _, err := r.loadSshGateway(&sshGateways[i]); err != nil {
			log.Warn("loadSshGateway failed, id=" + sshGateways[i].ID + " err=" + err.Error())
		}
	}
	return nil
}

// ReLoad removes any gateway registered under m.ID and registers m in its
// place. It returns the gateway produced by gateway.GlobalGatewayManager.Add.
func (r gatewayService) ReLoad(m *model.AccessGateway) *gateway.Gateway {
	r.DisconnectById(m.ID)
	return gateway.GlobalGatewayManager.Add(m)
}

// DisconnectById removes the gateway registered under id from
// gateway.GlobalGatewayManager. NOTE(review): whether Del also closes open
// tunnels/connections is decided by the gateway package — confirm there.
func (r gatewayService) DisconnectById(id string) {
	gateway.GlobalGatewayManager.Del(id)
}