feat: stream return.
This commit is contained in:
+176
-66
@@ -3,7 +3,7 @@
|
|||||||
package controllers
|
package controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context" // 在此处添加 context 导入
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -33,37 +33,30 @@ const (
|
|||||||
sourceWechat = "wechat"
|
sourceWechat = "wechat"
|
||||||
)
|
)
|
||||||
|
|
||||||
// readDocxContent 读取 .docx 文件并将其转换为结构化文本
|
|
||||||
// 修改为先保存临时文件再读取
|
|
||||||
func readDocxContent(fileHeader *multipart.FileHeader) (string, error) {
|
func readDocxContent(fileHeader *multipart.FileHeader) (string, error) {
|
||||||
// 1. 创建临时文件
|
|
||||||
tempFile, err := os.CreateTemp("", "upload_*.docx")
|
tempFile, err := os.CreateTemp("", "upload_*.docx")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tempFile.Name()) // 确保函数结束时删除临时文件
|
defer os.Remove(tempFile.Name())
|
||||||
defer tempFile.Close()
|
defer tempFile.Close()
|
||||||
|
|
||||||
// 2. 打开上传的文件流
|
|
||||||
src, err := fileHeader.Open()
|
src, err := fileHeader.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
||||||
}
|
}
|
||||||
defer src.Close()
|
defer src.Close()
|
||||||
|
|
||||||
// 3. 将上传的文件内容复制到临时文件
|
|
||||||
_, err = io.Copy(tempFile, src)
|
_, err = io.Copy(tempFile, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to copy file to temporary location: %w", err)
|
return "", fmt.Errorf("failed to copy file to temporary location: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 获取临时文件的完整路径
|
|
||||||
tempFilePath := tempFile.Name()
|
tempFilePath := tempFile.Name()
|
||||||
str, err := util.DocxToStructuredPrompt(tempFilePath)
|
str, err := util.DocxToStructuredPrompt(tempFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to parse docx with go-docx: %w", err)
|
return "", fmt.Errorf("failed to parse docx with go-docx: %w", err)
|
||||||
}
|
}
|
||||||
// 注意:表格、图片等复杂元素的处理可能需要更复杂的逻辑,这里仅处理简单文本
|
|
||||||
|
|
||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
@@ -76,54 +69,41 @@ func readDocxContentFromPath(filePath string) (string, error) {
|
|||||||
return str, nil
|
return str, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// readCSVContent 读取 .csv 文件内容
|
|
||||||
// 修改为先保存临时文件再读取
|
|
||||||
// readCSVContent 读取 .csv 文件内容
|
|
||||||
// 修改压缩策略:每 4 行保留 1 行数据
|
|
||||||
func readCSVContent(fileHeader *multipart.FileHeader) (string, error) {
|
func readCSVContent(fileHeader *multipart.FileHeader) (string, error) {
|
||||||
// 1. 创建临时文件
|
|
||||||
tempFile, err := os.CreateTemp("", "upload_*.csv")
|
tempFile, err := os.CreateTemp("", "upload_*.csv")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||||
}
|
}
|
||||||
defer os.Remove(tempFile.Name()) // 确保函数结束时删除临时文件
|
defer os.Remove(tempFile.Name())
|
||||||
defer tempFile.Close()
|
defer tempFile.Close()
|
||||||
|
|
||||||
// 2. 打开上传的文件流
|
|
||||||
src, err := fileHeader.Open()
|
src, err := fileHeader.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
||||||
}
|
}
|
||||||
defer src.Close()
|
defer src.Close()
|
||||||
|
|
||||||
// 3. 将上传的文件内容复制到临时文件
|
|
||||||
_, err = io.Copy(tempFile, src)
|
_, err = io.Copy(tempFile, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to copy file to temporary location: %w", err)
|
return "", fmt.Errorf("failed to copy file to temporary location: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 读取临时文件内容
|
|
||||||
content, err := ioutil.ReadFile(tempFile.Name())
|
content, err := ioutil.ReadFile(tempFile.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to read CSV content from temporary file: %w", err)
|
return "", fmt.Errorf("failed to read CSV content from temporary file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- 修改逻辑开始:每 4 行保留 1 行 ---
|
|
||||||
lines := strings.Split(string(content), "\n")
|
lines := strings.Split(string(content), "\n")
|
||||||
var compressedLines []string
|
var compressedLines []string
|
||||||
|
|
||||||
for i, line := range lines {
|
for i, line := range lines {
|
||||||
// 1. 必须保留第一行(表头),让 AI 知道每一列是什么
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
compressedLines = append(compressedLines, line)
|
compressedLines = append(compressedLines, line)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 跳过空行
|
|
||||||
if strings.TrimSpace(line) == "" {
|
if strings.TrimSpace(line) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i-1)%4 == 0 {
|
if (i-1)%4 == 0 {
|
||||||
compressedLines = append(compressedLines, line)
|
compressedLines = append(compressedLines, line)
|
||||||
}
|
}
|
||||||
@@ -133,7 +113,6 @@ func readCSVContent(fileHeader *multipart.FileHeader) (string, error) {
|
|||||||
return resultContent, nil
|
return resultContent, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildAnalysisPrompt 构建发送给 AI 的提示词
|
|
||||||
func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent string) string {
|
func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent string) string {
|
||||||
if analysisType == analysisTypeHeartRateWithSteps {
|
if analysisType == analysisTypeHeartRateWithSteps {
|
||||||
return fmt.Sprintf(`请根据以下体育课堂的教案、心率监测数据和训练结束步数汇总,生成一份详细的课堂分析报告:
|
return fmt.Sprintf(`请根据以下体育课堂的教案、心率监测数据和训练结束步数汇总,生成一份详细的课堂分析报告:
|
||||||
@@ -190,7 +169,7 @@ func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, st
|
|||||||
| **结束部分** | 社会性及情感目标游戏 | | | | | 4 |
|
| **结束部分** | 社会性及情感目标游戏 | | | | | 4 |
|
||||||
| | 整理放松 | | | | | 2 |
|
| | 整理放松 | | | | | 2 |
|
||||||
|
|
||||||
请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含“好的”、“收到”、“作为一名...”等任何开场白或客套话。`, teachingPlanContent, heartRateContent, stepContent)
|
请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含"好的"、"收到"、"作为一名..."等任何开场白或客套话。`, teachingPlanContent, heartRateContent, stepContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf(`请根据以下体育课堂的教案和心率监测数据,生成一份详细的课堂分析报告:
|
return fmt.Sprintf(`请根据以下体育课堂的教案和心率监测数据,生成一份详细的课堂分析报告:
|
||||||
@@ -238,7 +217,7 @@ func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, st
|
|||||||
| **结束部分** | 社会性及情感目标游戏 | | | | | 4 |
|
| **结束部分** | 社会性及情感目标游戏 | | | | | 4 |
|
||||||
| | 整理放松 | | | | | 2 |
|
| | 整理放松 | | | | | 2 |
|
||||||
|
|
||||||
请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含“好的”、“收到”、“作为一名...”等任何开场白或客套话。`, teachingPlanContent, heartRateContent)
|
请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含"好的"、"收到"、"作为一名..."等任何开场白或客套话。`, teachingPlanContent, heartRateContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
type aiAnalysisResult struct {
|
type aiAnalysisResult struct {
|
||||||
@@ -251,12 +230,10 @@ type aiAnalysisResult struct {
|
|||||||
OutputSizeBytes int
|
OutputSizeBytes int
|
||||||
}
|
}
|
||||||
|
|
||||||
// callAIForAnalysis 调用大模型进行分析
|
|
||||||
func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) {
|
func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) {
|
||||||
sizeInBytes := len(prompt)
|
sizeInBytes := len(prompt)
|
||||||
sizeInKB := float64(sizeInBytes) / 1024.0
|
sizeInKB := float64(sizeInBytes) / 1024.0
|
||||||
|
|
||||||
// 在日志中打印大小,保留两位小数
|
|
||||||
log.Printf("=== 发送给 AI 的内容大小: %.2f KB (%d 字节) ===", sizeInKB, sizeInBytes)
|
log.Printf("=== 发送给 AI 的内容大小: %.2f KB (%d 字节) ===", sizeInKB, sizeInBytes)
|
||||||
baseURL, apiKey, model, err := config.GetAIConfig()
|
baseURL, apiKey, model, err := config.GetAIConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -277,9 +254,9 @@ func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) {
|
|||||||
Content: prompt,
|
Content: prompt,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Temperature: 0.6, // 可调整
|
Temperature: 0.6,
|
||||||
TopP: 0.6, // 可调整
|
TopP: 0.6,
|
||||||
MaxTokens: 4000, // 根据需要调整
|
MaxTokens: 4000,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -306,9 +283,7 @@ func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AnalyzeByAI Gin 控制器方法
|
|
||||||
func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
||||||
// 1. 解析多部分表单请求
|
|
||||||
form, err := c.MultipartForm()
|
form, err := c.MultipartForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error parsing multipart form: %v", err)
|
log.Printf("Error parsing multipart form: %v", err)
|
||||||
@@ -316,13 +291,14 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取文件列表
|
csvFiles := form.File["heart_rate_data"]
|
||||||
csvFiles := form.File["heart_rate_data"] // 假设前端字段名为 'heart_rate_data'
|
|
||||||
stepFiles := form.File["step_data"]
|
stepFiles := form.File["step_data"]
|
||||||
analysisType := c.PostForm("analysis_type")
|
analysisType := c.PostForm("analysis_type")
|
||||||
teachingPlanSource := c.PostForm("teaching_plan_source")
|
teachingPlanSource := c.PostForm("teaching_plan_source")
|
||||||
regionIDStr := c.PostForm("regionid")
|
regionIDStr := c.PostForm("regionid")
|
||||||
trainID := c.PostForm("trainid")
|
trainID := c.PostForm("trainid")
|
||||||
|
streamStr := c.PostForm("stream")
|
||||||
|
useStream := streamStr == "true"
|
||||||
if analysisType == "" {
|
if analysisType == "" {
|
||||||
analysisType = analysisTypeHeartRateOnly
|
analysisType = analysisTypeHeartRateOnly
|
||||||
}
|
}
|
||||||
@@ -341,8 +317,6 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
|
|
||||||
uploadTime := time.Now().UnixMilli()
|
uploadTime := time.Now().UnixMilli()
|
||||||
|
|
||||||
// 3. 读取文件内容
|
|
||||||
// 注意:这里我们只取第一个上传的文件
|
|
||||||
heartRateFileHeader := csvFiles[0]
|
heartRateFileHeader := csvFiles[0]
|
||||||
teachingPlanContent, teachingPlanSize, err := resolveTeachingPlanContent(c, form, teachingPlanSource)
|
teachingPlanContent, teachingPlanSize, err := resolveTeachingPlanContent(c, form, teachingPlanSource)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -375,24 +349,13 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算文件大小
|
|
||||||
originalFileSize := heartRateFileHeader.Size + teachingPlanSize + stepFileSize
|
originalFileSize := heartRateFileHeader.Size + teachingPlanSize + stepFileSize
|
||||||
compressedContentSize := int64(len(heartRateContent)) + int64(len(teachingPlanContent)) + int64(len(stepContent))
|
compressedContentSize := int64(len(heartRateContent)) + int64(len(teachingPlanContent)) + int64(len(stepContent))
|
||||||
|
|
||||||
// 4. 构建 Prompt
|
|
||||||
prompt := buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent)
|
prompt := buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent)
|
||||||
|
|
||||||
// 5. 调用 AI 分析
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
analysisResult, err := callAIForAnalysis(prompt)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error calling AI for analysis: %v", err)
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("AI analysis failed: %v", err)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
durationMs := time.Since(startTime).Milliseconds()
|
|
||||||
|
|
||||||
// 6. 保存分析记录
|
|
||||||
var regionID *uint32
|
var regionID *uint32
|
||||||
if regionIDStr != "" {
|
if regionIDStr != "" {
|
||||||
if parsed, err := strconv.ParseUint(regionIDStr, 10, 32); err == nil {
|
if parsed, err := strconv.ParseUint(regionIDStr, 10, 32); err == nil {
|
||||||
@@ -401,7 +364,160 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算费用
|
if useStream {
|
||||||
|
tc.streamAIAnalysis(c, prompt, regionID, trainID, teachingPlanSource, analysisType,
|
||||||
|
originalFileSize, compressedContentSize, uploadTime)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
analysisResult, err := callAIForAnalysis(prompt)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error calling AI for analysis: %v", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("AI analysis failed: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
durationMs := time.Since(startTime).Milliseconds()
|
||||||
|
|
||||||
|
saveAnalysisRecord(analysisResult.Content, analysisResult.InputTokens, analysisResult.OutputTokens,
|
||||||
|
analysisResult.CacheHitTokens, analysisResult.CacheMissTokens,
|
||||||
|
analysisResult.InputSizeBytes, analysisResult.OutputSizeBytes,
|
||||||
|
regionID, trainID, teachingPlanSource, analysisType,
|
||||||
|
originalFileSize, compressedContentSize, uploadTime, durationMs)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "success",
|
||||||
|
"data": analysisResult.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamCollector struct {
|
||||||
|
fullContent string
|
||||||
|
inputTokens int
|
||||||
|
outputTokens int
|
||||||
|
cacheHitTokens int
|
||||||
|
cacheMissTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStreamCollector() *streamCollector {
|
||||||
|
return &streamCollector{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *streamCollector) add(delta string) {
|
||||||
|
sc.fullContent += delta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *streamCollector) updateUsage(usage *openai.Usage) {
|
||||||
|
sc.inputTokens = usage.PromptTokens
|
||||||
|
sc.outputTokens = usage.CompletionTokens
|
||||||
|
if usage.PromptTokensDetails != nil {
|
||||||
|
sc.cacheHitTokens = usage.PromptTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
sc.cacheMissTokens = sc.inputTokens - sc.cacheHitTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TrainingController) streamAIAnalysis(c *gin.Context, prompt string,
|
||||||
|
regionID *uint32, trainID, sourceType, analysisType string,
|
||||||
|
originalFileSize, compressedContentSize int64, uploadTime int64) {
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("streaming not supported")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming not supported"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL, apiKey, model, err := config.GetAIConfig()
|
||||||
|
if err != nil {
|
||||||
|
sendSSEError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConfig := openai.DefaultConfig(apiKey)
|
||||||
|
clientConfig.BaseURL = baseURL
|
||||||
|
client := openai.NewClientWithConfig(clientConfig)
|
||||||
|
|
||||||
|
stream, err := client.CreateChatCompletionStream(
|
||||||
|
c.Request.Context(),
|
||||||
|
openai.ChatCompletionRequest{
|
||||||
|
Model: model,
|
||||||
|
Messages: []openai.ChatCompletionMessage{
|
||||||
|
{Role: openai.ChatMessageRoleUser, Content: prompt},
|
||||||
|
},
|
||||||
|
Temperature: 0.6,
|
||||||
|
TopP: 0.6,
|
||||||
|
MaxTokens: 4000,
|
||||||
|
Stream: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
sendSSEError(c, fmt.Sprintf("stream failed: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
collector := newStreamCollector()
|
||||||
|
|
||||||
|
for {
|
||||||
|
response, recvErr := stream.Recv()
|
||||||
|
if recvErr != nil {
|
||||||
|
if recvErr == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
sendSSEError(c, fmt.Sprintf("stream recv error: %v", recvErr))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(response.Choices) > 0 {
|
||||||
|
delta := response.Choices[0].Delta.Content
|
||||||
|
collector.add(delta)
|
||||||
|
sendSSEData(c, map[string]interface{}{"content": delta})
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
if response.Usage != nil {
|
||||||
|
collector.updateUsage(response.Usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
durationMs := time.Since(startTime).Milliseconds()
|
||||||
|
|
||||||
|
saveAnalysisRecord(collector.fullContent, collector.inputTokens, collector.outputTokens,
|
||||||
|
collector.cacheHitTokens, collector.cacheMissTokens,
|
||||||
|
len(prompt), len(collector.fullContent),
|
||||||
|
regionID, trainID, sourceType, analysisType,
|
||||||
|
originalFileSize, compressedContentSize, uploadTime, durationMs)
|
||||||
|
|
||||||
|
sendSSEData(c, map[string]interface{}{
|
||||||
|
"done": true,
|
||||||
|
"inputTokens": collector.inputTokens,
|
||||||
|
"outputTokens": collector.outputTokens,
|
||||||
|
"cacheHitTokens": collector.cacheHitTokens,
|
||||||
|
})
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendSSEData(c *gin.Context, data map[string]interface{}) {
|
||||||
|
b, _ := json.Marshal(data)
|
||||||
|
fmt.Fprintf(c.Writer, "data: %s\n\n", string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendSSEError(c *gin.Context, msg string) {
|
||||||
|
b, _ := json.Marshal(map[string]string{"error": msg})
|
||||||
|
fmt.Fprintf(c.Writer, "data: %s\n\n", string(b))
|
||||||
|
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveAnalysisRecord(content string, inputTokens, outputTokens, cacheHitTokens, cacheMissTokens,
|
||||||
|
inputSizeBytes, outputSizeBytes int,
|
||||||
|
regionID *uint32, trainID, sourceType, analysisType string,
|
||||||
|
originalFileSize, compressedContentSize int64, uploadTime int64, durationMs int64) {
|
||||||
|
|
||||||
var pricing models.AIPricingConfig
|
var pricing models.AIPricingConfig
|
||||||
var costJSON string
|
var costJSON string
|
||||||
var totalCost float64
|
var totalCost float64
|
||||||
@@ -414,9 +530,9 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
if cacheHitPrice == 0 {
|
if cacheHitPrice == 0 {
|
||||||
cacheHitPrice = pricing.InputPricePerMillion
|
cacheHitPrice = pricing.InputPricePerMillion
|
||||||
}
|
}
|
||||||
cacheHitCost := float64(analysisResult.CacheHitTokens) * cacheHitPrice / 1_000_000
|
cacheHitCost := float64(cacheHitTokens) * cacheHitPrice / 1_000_000
|
||||||
cacheMissCost := float64(analysisResult.CacheMissTokens) * cacheMissPrice / 1_000_000
|
cacheMissCost := float64(cacheMissTokens) * cacheMissPrice / 1_000_000
|
||||||
outputCost := float64(analysisResult.OutputTokens) * pricing.OutputPricePerMillion / 1_000_000
|
outputCost := float64(outputTokens) * pricing.OutputPricePerMillion / 1_000_000
|
||||||
totalCost = cacheHitCost + cacheMissCost + outputCost
|
totalCost = cacheHitCost + cacheMissCost + outputCost
|
||||||
|
|
||||||
costInfo := map[string]interface{}{
|
costInfo := map[string]interface{}{
|
||||||
@@ -438,17 +554,17 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
record := models.AIAnalysisRecord{
|
record := models.AIAnalysisRecord{
|
||||||
RegionID: regionID,
|
RegionID: regionID,
|
||||||
TrainId: trainID,
|
TrainId: trainID,
|
||||||
SourceType: teachingPlanSource,
|
SourceType: sourceType,
|
||||||
AnalysisType: analysisType,
|
AnalysisType: analysisType,
|
||||||
AnalysisResult: analysisResult.Content,
|
AnalysisResult: content,
|
||||||
CostJSON: costJSON,
|
CostJSON: costJSON,
|
||||||
TotalCost: totalCost,
|
TotalCost: totalCost,
|
||||||
InputTokens: analysisResult.InputTokens,
|
InputTokens: inputTokens,
|
||||||
OutputTokens: analysisResult.OutputTokens,
|
OutputTokens: outputTokens,
|
||||||
CacheHitTokens: analysisResult.CacheHitTokens,
|
CacheHitTokens: cacheHitTokens,
|
||||||
CacheMissTokens: analysisResult.CacheMissTokens,
|
CacheMissTokens: cacheMissTokens,
|
||||||
InputSizeBytes: analysisResult.InputSizeBytes,
|
InputSizeBytes: inputSizeBytes,
|
||||||
OutputSizeBytes: analysisResult.OutputSizeBytes,
|
OutputSizeBytes: outputSizeBytes,
|
||||||
DurationMs: durationMs,
|
DurationMs: durationMs,
|
||||||
OriginalFileSize: originalFileSize,
|
OriginalFileSize: originalFileSize,
|
||||||
CompressedContentSize: compressedContentSize,
|
CompressedContentSize: compressedContentSize,
|
||||||
@@ -457,12 +573,6 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
|||||||
if err := config.DB.Create(&record).Error; err != nil {
|
if err := config.DB.Create(&record).Error; err != nil {
|
||||||
log.Printf("Failed to save analysis record: %v", err)
|
log.Printf("Failed to save analysis record: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. 返回结果
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"status": "success",
|
|
||||||
"data": analysisResult.Content,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveTeachingPlanContent(c *gin.Context, form *multipart.Form, source string) (string, int64, error) {
|
func resolveTeachingPlanContent(c *gin.Context, form *multipart.Form, source string) (string, int64, error) {
|
||||||
|
|||||||
@@ -55,3 +55,10 @@ func (w *gzipResponseWriter) Write(data []byte) (int, error) {
|
|||||||
func (w *gzipResponseWriter) WriteString(s string) (int, error) {
|
func (w *gzipResponseWriter) WriteString(s string) (int, error) {
|
||||||
return w.gzWriter.Write([]byte(s))
|
return w.gzWriter.Write([]byte(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *gzipResponseWriter) Flush() {
|
||||||
|
w.gzWriter.Flush()
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user