From ae5d70b951d6e984c204cdfb90f90e7d7e33a30c Mon Sep 17 00:00:00 2001 From: laoboli <1293528695@qq.com> Date: Sat, 2 May 2026 18:08:23 +0800 Subject: [PATCH] feat: stream return. --- controllers/ai.go | 242 ++++++++++++++++++++++++++++++++------------- middleware/gzip.go | 7 ++ 2 files changed, 183 insertions(+), 66 deletions(-) diff --git a/controllers/ai.go b/controllers/ai.go index aa08a6e..6dfb890 100644 --- a/controllers/ai.go +++ b/controllers/ai.go @@ -3,7 +3,7 @@ package controllers import ( - "context" // 在此处添加 context 导入 + "context" "encoding/json" "errors" "fmt" @@ -33,37 +33,30 @@ const ( sourceWechat = "wechat" ) -// readDocxContent 读取 .docx 文件并将其转换为结构化文本 -// 修改为先保存临时文件再读取 func readDocxContent(fileHeader *multipart.FileHeader) (string, error) { - // 1. 创建临时文件 tempFile, err := os.CreateTemp("", "upload_*.docx") if err != nil { return "", fmt.Errorf("failed to create temporary file: %w", err) } - defer os.Remove(tempFile.Name()) // 确保函数结束时删除临时文件 + defer os.Remove(tempFile.Name()) defer tempFile.Close() - // 2. 打开上传的文件流 src, err := fileHeader.Open() if err != nil { return "", fmt.Errorf("failed to open uploaded file: %w", err) } defer src.Close() - // 3. 将上传的文件内容复制到临时文件 _, err = io.Copy(tempFile, src) if err != nil { return "", fmt.Errorf("failed to copy file to temporary location: %w", err) } - // 4. 获取临时文件的完整路径 tempFilePath := tempFile.Name() str, err := util.DocxToStructuredPrompt(tempFilePath) if err != nil { return "", fmt.Errorf("failed to parse docx with go-docx: %w", err) } - // 注意:表格、图片等复杂元素的处理可能需要更复杂的逻辑,这里仅处理简单文本 return str, nil } @@ -76,54 +69,41 @@ func readDocxContentFromPath(filePath string) (string, error) { return str, nil } -// readCSVContent 读取 .csv 文件内容 -// 修改为先保存临时文件再读取 -// readCSVContent 读取 .csv 文件内容 -// 修改压缩策略:每 4 行保留 1 行数据 func readCSVContent(fileHeader *multipart.FileHeader) (string, error) { - // 1. 创建临时文件 tempFile, err := os.CreateTemp("", "upload_*.csv") if err != nil { return "", fmt.Errorf("failed to create temporary file: %w", err) } - defer os.Remove(tempFile.Name()) // 确保函数结束时删除临时文件 + defer os.Remove(tempFile.Name()) defer tempFile.Close() - // 2. 打开上传的文件流 src, err := fileHeader.Open() if err != nil { return "", fmt.Errorf("failed to open uploaded file: %w", err) } defer src.Close() - // 3. 将上传的文件内容复制到临时文件 _, err = io.Copy(tempFile, src) if err != nil { return "", fmt.Errorf("failed to copy file to temporary location: %w", err) } - // 4. 读取临时文件内容 content, err := ioutil.ReadFile(tempFile.Name()) if err != nil { return "", fmt.Errorf("failed to read CSV content from temporary file: %w", err) } - // --- 修改逻辑开始:每 4 行保留 1 行 --- lines := strings.Split(string(content), "\n") var compressedLines []string for i, line := range lines { - // 1. 必须保留第一行(表头),让 AI 知道每一列是什么 if i == 0 { compressedLines = append(compressedLines, line) continue } - - // 2. 跳过空行 if strings.TrimSpace(line) == "" { continue } - if (i-1)%4 == 0 { compressedLines = append(compressedLines, line) } @@ -133,7 +113,6 @@ func readCSVContent(fileHeader *multipart.FileHeader) (string, error) { return resultContent, nil } -// buildAnalysisPrompt 构建发送给 AI 的提示词 func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent string) string { if analysisType == analysisTypeHeartRateWithSteps { return fmt.Sprintf(`请根据以下体育课堂的教案、心率监测数据和训练结束步数汇总,生成一份详细的课堂分析报告: @@ -190,7 +169,7 @@ func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, st | **结束部分** | 社会性及情感目标游戏 | | | | | 4 | | | 整理放松 | | | | | 2 | -请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含“好的”、“收到”、“作为一名...”等任何开场白或客套话。`, teachingPlanContent, heartRateContent, stepContent) +请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含"好的"、"收到"、"作为一名..."等任何开场白或客套话。`, teachingPlanContent, heartRateContent, stepContent) } return fmt.Sprintf(`请根据以下体育课堂的教案和心率监测数据,生成一份详细的课堂分析报告: @@ -238,7 +217,7 @@ func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, st | **结束部分** | 社会性及情感目标游戏 | | | | | 4 | | | 整理放松 | | | | | 2 | -请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含“好的”、“收到”、“作为一名...”等任何开场白或客套话。`, teachingPlanContent, heartRateContent) +请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含"好的"、"收到"、"作为一名..."等任何开场白或客套话。`, teachingPlanContent, heartRateContent) } type aiAnalysisResult struct { @@ -251,12 +230,10 @@ type aiAnalysisResult struct { OutputSizeBytes int } -// callAIForAnalysis 调用大模型进行分析 func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) { sizeInBytes := len(prompt) sizeInKB := float64(sizeInBytes) / 1024.0 - // 在日志中打印大小,保留两位小数 log.Printf("=== 发送给 AI 的内容大小: %.2f KB (%d 字节) ===", sizeInKB, sizeInBytes) baseURL, apiKey, model, err := config.GetAIConfig() if err != nil { @@ -277,9 +254,9 @@ func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) { Content: prompt, }, }, - Temperature: 0.6, // 可调整 - TopP: 0.6, // 可调整 - MaxTokens: 4000, // 根据需要调整 + Temperature: 0.6, + TopP: 0.6, + MaxTokens: 4000, }, ) if err != nil { @@ -306,9 +283,7 @@ func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) { }, nil } -// AnalyzeByAI Gin 控制器方法 func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { - // 1. 解析多部分表单请求 form, err := c.MultipartForm() if err != nil { log.Printf("Error parsing multipart form: %v", err) @@ -316,13 +291,14 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { return } - // 2. 获取文件列表 - csvFiles := form.File["heart_rate_data"] // 假设前端字段名为 'heart_rate_data' + csvFiles := form.File["heart_rate_data"] stepFiles := form.File["step_data"] analysisType := c.PostForm("analysis_type") teachingPlanSource := c.PostForm("teaching_plan_source") regionIDStr := c.PostForm("regionid") trainID := c.PostForm("trainid") + streamStr := c.PostForm("stream") + useStream := streamStr == "true" if analysisType == "" { analysisType = analysisTypeHeartRateOnly } @@ -341,8 +317,6 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { uploadTime := time.Now().UnixMilli() - // 3. 读取文件内容 - // 注意:这里我们只取第一个上传的文件 heartRateFileHeader := csvFiles[0] teachingPlanContent, teachingPlanSize, err := resolveTeachingPlanContent(c, form, teachingPlanSource) if err != nil { @@ -375,24 +349,13 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { } } - // 计算文件大小 originalFileSize := heartRateFileHeader.Size + teachingPlanSize + stepFileSize compressedContentSize := int64(len(heartRateContent)) + int64(len(teachingPlanContent)) + int64(len(stepContent)) - // 4. 构建 Prompt prompt := buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent) - // 5. 调用 AI 分析 startTime := time.Now() - analysisResult, err := callAIForAnalysis(prompt) - if err != nil { - log.Printf("Error calling AI for analysis: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("AI analysis failed: %v", err)}) - return - } - durationMs := time.Since(startTime).Milliseconds() - // 6. 保存分析记录 var regionID *uint32 if regionIDStr != "" { if parsed, err := strconv.ParseUint(regionIDStr, 10, 32); err == nil { @@ -401,7 +364,160 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { } } - // 计算费用 + if useStream { + tc.streamAIAnalysis(c, prompt, regionID, trainID, teachingPlanSource, analysisType, + originalFileSize, compressedContentSize, uploadTime) + return + } + + analysisResult, err := callAIForAnalysis(prompt) + if err != nil { + log.Printf("Error calling AI for analysis: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("AI analysis failed: %v", err)}) + return + } + durationMs := time.Since(startTime).Milliseconds() + + saveAnalysisRecord(analysisResult.Content, analysisResult.InputTokens, analysisResult.OutputTokens, + analysisResult.CacheHitTokens, analysisResult.CacheMissTokens, + analysisResult.InputSizeBytes, analysisResult.OutputSizeBytes, + regionID, trainID, teachingPlanSource, analysisType, + originalFileSize, compressedContentSize, uploadTime, durationMs) + + c.JSON(http.StatusOK, gin.H{ + "status": "success", + "data": analysisResult.Content, + }) +} + +type streamCollector struct { + fullContent string + inputTokens int + outputTokens int + cacheHitTokens int + cacheMissTokens int +} + +func newStreamCollector() *streamCollector { + return &streamCollector{} +} + +func (sc *streamCollector) add(delta string) { + sc.fullContent += delta +} + +func (sc *streamCollector) updateUsage(usage *openai.Usage) { + sc.inputTokens = usage.PromptTokens + sc.outputTokens = usage.CompletionTokens + if usage.PromptTokensDetails != nil { + sc.cacheHitTokens = usage.PromptTokensDetails.CachedTokens + } + sc.cacheMissTokens = sc.inputTokens - sc.cacheHitTokens +} + +func (tc *TrainingController) streamAIAnalysis(c *gin.Context, prompt string, + regionID *uint32, trainID, sourceType, analysisType string, + originalFileSize, compressedContentSize int64, uploadTime int64) { + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.WriteHeader(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + log.Printf("streaming not supported") + c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"}) + return + } + + baseURL, apiKey, model, err := config.GetAIConfig() + if err != nil { + sendSSEError(c, err.Error()) + return + } + + clientConfig := openai.DefaultConfig(apiKey) + clientConfig.BaseURL = baseURL + client := openai.NewClientWithConfig(clientConfig) + + stream, err := client.CreateChatCompletionStream( + c.Request.Context(), + openai.ChatCompletionRequest{ + Model: model, + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: prompt}, + }, + Temperature: 0.6, + TopP: 0.6, + MaxTokens: 4000, + Stream: true, + }, + ) + if err != nil { + sendSSEError(c, fmt.Sprintf("stream failed: %v", err)) + return + } + defer stream.Close() + + startTime := time.Now() + collector := newStreamCollector() + + for { + response, recvErr := stream.Recv() + if recvErr != nil { + if recvErr == io.EOF { + break + } + sendSSEError(c, fmt.Sprintf("stream recv error: %v", recvErr)) + return + } + if len(response.Choices) > 0 { + delta := response.Choices[0].Delta.Content + collector.add(delta) + sendSSEData(c, map[string]interface{}{"content": delta}) + flusher.Flush() + } + if response.Usage != nil { + collector.updateUsage(response.Usage) + } + } + + durationMs := time.Since(startTime).Milliseconds() + + saveAnalysisRecord(collector.fullContent, collector.inputTokens, collector.outputTokens, + collector.cacheHitTokens, collector.cacheMissTokens, + len(prompt), len(collector.fullContent), + regionID, trainID, sourceType, analysisType, + originalFileSize, compressedContentSize, uploadTime, durationMs) + + sendSSEData(c, map[string]interface{}{ + "done": true, + "inputTokens": collector.inputTokens, + "outputTokens": collector.outputTokens, + "cacheHitTokens": collector.cacheHitTokens, + }) + flusher.Flush() +} + +func sendSSEData(c *gin.Context, data map[string]interface{}) { + b, _ := json.Marshal(data) + fmt.Fprintf(c.Writer, "data: %s\n\n", string(b)) +} + +func sendSSEError(c *gin.Context, msg string) { + b, _ := json.Marshal(map[string]string{"error": msg}) + fmt.Fprintf(c.Writer, "data: %s\n\n", string(b)) + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() + } +} + +func saveAnalysisRecord(content string, inputTokens, outputTokens, cacheHitTokens, cacheMissTokens, + inputSizeBytes, outputSizeBytes int, + regionID *uint32, trainID, sourceType, analysisType string, + originalFileSize, compressedContentSize int64, uploadTime int64, durationMs int64) { + var pricing models.AIPricingConfig var costJSON string var totalCost float64 @@ -414,9 +530,9 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { if cacheHitPrice == 0 { cacheHitPrice = pricing.InputPricePerMillion } - cacheHitCost := float64(analysisResult.CacheHitTokens) * cacheHitPrice / 1_000_000 - cacheMissCost := float64(analysisResult.CacheMissTokens) * cacheMissPrice / 1_000_000 - outputCost := float64(analysisResult.OutputTokens) * pricing.OutputPricePerMillion / 1_000_000 + cacheHitCost := float64(cacheHitTokens) * cacheHitPrice / 1_000_000 + cacheMissCost := float64(cacheMissTokens) * cacheMissPrice / 1_000_000 + outputCost := float64(outputTokens) * pricing.OutputPricePerMillion / 1_000_000 totalCost = cacheHitCost + cacheMissCost + outputCost costInfo := map[string]interface{}{ @@ -438,17 +554,17 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { record := models.AIAnalysisRecord{ RegionID: regionID, TrainId: trainID, - SourceType: teachingPlanSource, + SourceType: sourceType, AnalysisType: analysisType, - AnalysisResult: analysisResult.Content, + AnalysisResult: content, CostJSON: costJSON, TotalCost: totalCost, - InputTokens: analysisResult.InputTokens, - OutputTokens: analysisResult.OutputTokens, - CacheHitTokens: analysisResult.CacheHitTokens, - CacheMissTokens: analysisResult.CacheMissTokens, - InputSizeBytes: analysisResult.InputSizeBytes, - OutputSizeBytes: analysisResult.OutputSizeBytes, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheHitTokens: cacheHitTokens, + CacheMissTokens: cacheMissTokens, + InputSizeBytes: inputSizeBytes, + OutputSizeBytes: outputSizeBytes, DurationMs: durationMs, OriginalFileSize: originalFileSize, CompressedContentSize: compressedContentSize, @@ -457,12 +573,6 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { if err := config.DB.Create(&record).Error; err != nil { log.Printf("Failed to save analysis record: %v", err) } - - // 7. 返回结果 - c.JSON(http.StatusOK, gin.H{ - "status": "success", - "data": analysisResult.Content, - }) } func resolveTeachingPlanContent(c *gin.Context, form *multipart.Form, source string) (string, int64, error) { diff --git a/middleware/gzip.go b/middleware/gzip.go index 6259a02..78e4829 100644 --- a/middleware/gzip.go +++ b/middleware/gzip.go @@ -55,3 +55,10 @@ func (w *gzipResponseWriter) Write(data []byte) (int, error) { func (w *gzipResponseWriter) WriteString(s string) (int, error) { return w.gzWriter.Write([]byte(s)) } + +func (w *gzipResponseWriter) Flush() { + w.gzWriter.Flush() + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +}