package api

import (
	"bytes"
	"context"
	"regexp"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/gorilla/websocket"

	"next-terminal/server/common"
	"next-terminal/server/common/term"
	"next-terminal/server/dto"
	"next-terminal/server/global/session"
	"next-terminal/server/model"
	"next-terminal/server/repository"
	"next-terminal/server/utils"
)

// maxOutputLength caps, in bytes, how much terminal output is collected and
// persisted for a single recorded command.
const maxOutputLength = 10000

// termAnsiEscapeRe matches terminal control sequences that carry no readable
// text: OSC strings such as window-title updates (terminated by BEL or ST),
// CSI sequences (colours, cursor movement, bracketed-paste toggles),
// character-set designations and the remaining two-byte ESC sequences.
var termAnsiEscapeRe = regexp.MustCompile(`\x1b\][^\x07\x1b]*(?:\x07|\x1b\\)|\x1b\[[0-?]*[ -/]*[@-~]|\x1b[()*+][ -~]|\x1b[@-Z\\-_]`)

// termHighRiskPatterns are lower-case command fragments classified as risk
// level 2. A fragment ending in "/" names the filesystem root and only
// matches when the path ends there (see matchesRiskPattern), so that
// "rm -rf /tmp/x" is not mistaken for "rm -rf /".
var termHighRiskPatterns = []string{
	"rm -rf /",
	"rm -rf /*",
	"mkfs",
	"dd if=",
	"> /dev/sd",
	"chmod -r 777",
	"chown -r",
	":(){ :|:& };:",
}

// termMediumRiskPatterns are lower-case command fragments classified as risk
// level 1.
var termMediumRiskPatterns = []string{
	"rm -rf",
	"rm -r",
	"passwd",
	"useradd",
	"userdel",
	"chmod 777",
	"chmod -r",
	"chown",
	"wget",
	"curl",
	"apt-get",
	"yum install",
	"dnf install",
	"pacman",
	"systemctl stop",
	"systemctl disable",
	"service stop",
	"iptables -f",
	"ufw disable",
	"setenforce 0",
}

// termEscapeState tracks progress through an escape sequence typed by the
// user (arrow keys, function keys, Alt+key) so that its bytes are not
// mistaken for command text.
type termEscapeState int

const (
	termEscNone  termEscapeState = iota // not inside an escape sequence
	termEscStart                        // saw ESC
	termEscCSI                          // inside "ESC [ ...", waiting for the final byte
	termEscSS3                          // saw "ESC O", exactly one more byte follows
)

// TermHandler bridges a remote terminal (term.NextTerminal) and a websocket.
// It batches terminal output to the websocket, the recorder and session
// observers, and records the commands typed by the user together with a
// truncated, cleaned copy of their output.
type TermHandler struct {
	sessionId    string
	userId       string
	assetId      string
	isRecording  bool
	webSocket    *websocket.Conn
	nextTerminal *term.NextTerminal
	ctx          context.Context
	cancel       context.CancelFunc
	dataChan     chan rune
	tick         *time.Ticker

	// mutex serializes writes to webSocket.
	mutex sync.Mutex
	// buf holds output pending for the next tick; used only by writeToWebsocket.
	buf bytes.Buffer
	// commandBuf is the command line being typed; used only by Write.
	commandBuf bytes.Buffer

	// outputMutex guards outputBuf and lastCommand.
	outputMutex sync.Mutex
	// outputBuf collects at most maxOutputLength bytes of output for lastCommand.
	outputBuf bytes.Buffer
	// lastCommand is the command whose output is being collected; not yet persisted.
	lastCommand *model.SessionCommand
}

// NewTermHandler creates a handler for one terminal session. Output is
// flushed every 60ms once Start is called; Stop releases the handler.
func NewTermHandler(userId, assetId, sessionId string, isRecording bool, ws *websocket.Conn, nextTerminal *term.NextTerminal) *TermHandler {
	ctx, cancel := context.WithCancel(context.Background())
	return &TermHandler{
		sessionId:    sessionId,
		userId:       userId,
		assetId:      assetId,
		isRecording:  isRecording,
		webSocket:    ws,
		nextTerminal: nextTerminal,
		ctx:          ctx,
		cancel:       cancel,
		dataChan:     make(chan rune),
		tick:         time.NewTicker(60 * time.Millisecond),
	}
}

// Start launches the goroutines that read terminal output and forward it.
func (r *TermHandler) Start() {
	go r.readFormTunnel()
	go r.writeToWebsocket()
}

// Stop halts output forwarding and persists the command still awaiting its
// output. Calling it more than once is harmless: the ticker stop and context
// cancel are idempotent and the pending command is cleared after saving.
func (r *TermHandler) Stop() {
	r.tick.Stop()
	r.cancel()
	r.saveLastCommandOutput()
}

// readFormTunnel decodes terminal stdout rune by rune and hands each rune to
// writeToWebsocket. It exits on a read error or once the context is
// cancelled. The channel send also watches the context, so the goroutine
// cannot block forever after writeToWebsocket has returned. A ReadRune call
// that is already blocked is only released by the reader itself.
func (r *TermHandler) readFormTunnel() {
	for r.ctx.Err() == nil {
		rn, size, err := r.nextTerminal.StdoutReader.ReadRune()
		if err != nil {
			return
		}
		if size == 0 {
			continue
		}
		select {
		case r.dataChan <- rn:
		case <-r.ctx.Done():
			return
		}
	}
}

// writeToWebsocket accumulates runes from readFormTunnel. On every tick it
// flushes the batch to the websocket, the recorder (when recording), session
// observers and the output collected for the pending command. Input
// reported as utf8.RuneError is rendered as "@".
func (r *TermHandler) writeToWebsocket() {
	for {
		select {
		case <-r.ctx.Done():
			return
		case <-r.tick.C:
			if r.buf.Len() == 0 {
				continue
			}
			s := r.buf.String()
			r.buf.Reset()
			if err := r.SendMessageToWebSocket(dto.NewMessage(Data, s)); err != nil {
				// Nothing drains dataChan after this return; cancel so the
				// reader goroutine exits instead of blocking on its send.
				r.cancel()
				return
			}
			if r.isRecording && r.nextTerminal.Recorder != nil {
				// Best effort: a recorder failure must not interrupt the live session.
				_ = r.nextTerminal.Recorder.WriteData(s)
			}
			SendObData(r.sessionId, s)
			r.collectOutput(s)
		case data := <-r.dataChan:
			if data == utf8.RuneError {
				r.buf.WriteByte('@')
			} else {
				r.buf.WriteRune(data)
			}
		}
	}
}

// collectOutput appends terminal output to the buffer attributed to the
// pending command. Output produced while no command is pending (for example
// a login banner) is ignored.
//
// The buffer is capped at maxOutputLength bytes. The chunk that crosses the
// cap is cut at a rune boundary and later output is dropped, so the stored
// text is always a contiguous prefix of what the terminal printed.
func (r *TermHandler) collectOutput(output string) {
	r.outputMutex.Lock()
	defer r.outputMutex.Unlock()
	if r.lastCommand == nil {
		return
	}
	remaining := maxOutputLength - r.outputBuf.Len()
	if remaining <= 0 {
		return
	}
	r.outputBuf.WriteString(truncateUTF8Prefix(output, remaining))
}

// saveLastCommandOutput persists the pending command, if any, with the
// output collected so far, then clears it. Commands that produced no output
// are stored as well.
func (r *TermHandler) saveLastCommandOutput() {
	r.swapLastCommand(nil)
}

// swapLastCommand makes next the pending command (nil clears it). It then
// persists the previously pending command with the output collected for it.
//
// The swap and buffer reset happen under outputMutex, so output is never
// attributed to the wrong command. The database write happens after the lock
// is released, so collectOutput is not stalled by it.
func (r *TermHandler) swapLastCommand(next *model.SessionCommand) {
	r.outputMutex.Lock()
	prev := r.lastCommand
	rawOutput := r.outputBuf.String()
	r.lastCommand = next
	r.outputBuf.Reset()
	r.outputMutex.Unlock()

	if prev == nil {
		return
	}
	prev.Output = truncateUTF8Prefix(cleanOutput(rawOutput), maxOutputLength)
	// Best effort: failing to store the audit entry must not break the session.
	_ = repository.SessionCommandRepository.Create(context.TODO(), prev)
}

// truncateUTF8Prefix returns s cut to at most n bytes without splitting a
// multi-byte UTF-8 sequence. It returns "" when n <= 0.
func truncateUTF8Prefix(s string, n int) string {
	if len(s) <= n {
		return s
	}
	if n <= 0 {
		return ""
	}
	for n > 0 && !utf8.RuneStart(s[n]) {
		n--
	}
	return s[:n]
}

// cleanOutput turns raw terminal output into readable text. It strips
// terminal escape sequences, trims surrounding whitespace (including
// carriage returns) from every line, drops blank lines and trims the result.
func cleanOutput(output string) string {
	output = termAnsiEscapeRe.ReplaceAllString(output, "")
	var result strings.Builder
	for _, line := range strings.Split(output, "\n") {
		if line = strings.TrimSpace(line); line != "" {
			result.WriteString(line)
			result.WriteByte('\n')
		}
	}
	return strings.TrimSpace(result.String())
}

// Write forwards user input to the terminal. Along the way it reconstructs
// the command line being typed so it can be audited:
//   - printable ASCII is appended;
//   - Backspace/DEL removes the last character;
//   - Ctrl+C and Ctrl+U discard the line;
//   - escape sequences (arrow keys, function keys, Alt+key) are ignored;
//   - Enter (CR) records the line as a new command.
//
// This is an approximation: history recall, tab completion and in-line
// cursor movement are not reflected in the recorded text.
func (r *TermHandler) Write(input []byte) error {
	// Escape state is tracked per call, so a trailing lone ESC cannot swallow
	// the first byte of the next write. This assumes the client sends a lone
	// Esc key press in its own message.
	esc := termEscNone
	for _, b := range input {
		if consumeTermEscape(&esc, b) {
			continue
		}
		switch {
		case b == '\r': // Enter
			r.recordCommand()
		case b == 0x7f || b == '\b': // Backspace / Delete
			if n := r.commandBuf.Len(); n > 0 {
				r.commandBuf.Truncate(n - 1)
			}
		case b == 0x03 || b == 0x15: // Ctrl+C, Ctrl+U
			r.commandBuf.Reset()
		case b >= 0x20 && b < 0x7f: // printable ASCII
			r.commandBuf.WriteByte(b)
		}
	}
	_, err := r.nextTerminal.Write(input)
	return err
}

// recordCommand turns the typed line into the pending command (persisting
// the previous one) and clears the line buffer. Blank lines are ignored.
func (r *TermHandler) recordCommand() {
	command := strings.TrimSpace(r.commandBuf.String())
	r.commandBuf.Reset()
	if command == "" {
		return
	}
	r.swapLastCommand(&model.SessionCommand{
		ID:        utils.UUID(),
		SessionId: r.sessionId,
		RiskLevel: detectRiskLevel(command),
		Command:   command,
		Output:    "",
		Created:   common.NowJsonTime(),
	})
}

// consumeTermEscape advances the escape-sequence parser in *state by one
// input byte. It reports whether b belongs to an escape sequence and so must
// not reach the command line. A control byte interrupts an unfinished
// sequence and is then processed normally by the caller.
func consumeTermEscape(state *termEscapeState, b byte) bool {
	if b == 0x1b {
		*state = termEscStart
		return true
	}
	if *state == termEscNone {
		return false
	}
	if b < 0x20 || b == 0x7f {
		*state = termEscNone
		return false
	}
	switch *state {
	case termEscStart:
		switch b {
		case '[':
			*state = termEscCSI
		case 'O':
			*state = termEscSS3
		default:
			*state = termEscNone // two-byte sequence such as Alt+<key>
		}
	case termEscCSI:
		if b >= 0x40 && b <= 0x7e {
			*state = termEscNone // final byte ends the sequence
		}
	case termEscSS3:
		*state = termEscNone
	}
	return true
}

// detectRiskLevel classifies a command line by case-insensitive fragment
// matching: 2 for high-risk patterns, 1 for medium-risk ones, 0 otherwise.
func detectRiskLevel(command string) int {
	lowerCmd := strings.ToLower(command)
	for _, pattern := range termHighRiskPatterns {
		if matchesRiskPattern(lowerCmd, pattern) {
			return 2
		}
	}
	for _, pattern := range termMediumRiskPatterns {
		if matchesRiskPattern(lowerCmd, pattern) {
			return 1
		}
	}
	return 0
}

// matchesRiskPattern reports whether pattern occurs in cmd. A pattern ending
// in "/" only matches when followed by the end of the command, whitespace, a
// shell separator (; & |) or "*", i.e. when it targets the root directory
// itself rather than a path below it.
func matchesRiskPattern(cmd, pattern string) bool {
	if !strings.HasSuffix(pattern, "/") {
		return strings.Contains(cmd, pattern)
	}
	for from := 0; ; {
		idx := strings.Index(cmd[from:], pattern)
		if idx < 0 {
			return false
		}
		end := from + idx + len(pattern)
		if end == len(cmd) || strings.IndexByte(" \t;&|*", cmd[end]) >= 0 {
			return true
		}
		from += idx + 1
	}
}

// WindowChange forwards a terminal resize to the remote terminal. The
// arguments are passed through unchanged (presumably rows and columns).
func (r *TermHandler) WindowChange(h int, w int) error {
	return r.nextTerminal.WindowChange(h, w)
}

// SendRequest sends a global request named "helloworld1024@foxmail.com"
// (wantReply=true, no payload) over the SSH connection. It is presumably used
// as a keepalive probe. The returned error reports a broken connection.
func (r *TermHandler) SendRequest() error {
	_, _, err := r.nextTerminal.SshClient.Conn.SendRequest("helloworld1024@foxmail.com", true, nil)
	return err
}

// SendMessageToWebSocket writes msg as a websocket text frame. Writes are
// serialized with r.mutex because a websocket connection supports only one
// concurrent writer. It is a no-op when no websocket is attached.
func (r *TermHandler) SendMessageToWebSocket(msg dto.Message) error {
	if r.webSocket == nil {
		return nil
	}
	r.mutex.Lock()
	defer r.mutex.Unlock()
	return r.webSocket.WriteMessage(websocket.TextMessage, []byte(msg.ToString()))
}

// SendObData broadcasts terminal output s to every observer attached to the
// session identified by sessionId. Per-observer write errors are ignored so
// one slow or closed observer does not affect the others.
func SendObData(sessionId, s string) {
	nextSession := session.GlobalSessionManager.GetById(sessionId)
	if nextSession == nil || nextSession.Observer == nil {
		return
	}
	nextSession.Observer.Range(func(key string, ob *session.Session) {
		_ = ob.WriteMessage(dto.NewMessage(Data, s))
	})
}