From c2bb69bde67dce18b20837c7dd3e27677ced145a Mon Sep 17 00:00:00 2001 From: laoboli <1293528695@qq.com> Date: Wed, 29 Apr 2026 20:32:14 +0800 Subject: [PATCH] refactor: ai usage statics. --- controllers/ai.go | 89 +++++++++++---- controllers/statistics.go | 222 ++++++++++++++++++++++++++++++++++++++ main.go | 1 + models/analyze.go | 17 +++ routes/routes.go | 5 + 5 files changed, 315 insertions(+), 19 deletions(-) create mode 100644 controllers/statistics.go create mode 100644 models/analyze.go diff --git a/controllers/ai.go b/controllers/ai.go index 719e726..3a1dbfb 100644 --- a/controllers/ai.go +++ b/controllers/ai.go @@ -17,7 +17,9 @@ import ( "mime/multipart" "net/http" "os" + "strconv" "strings" + "time" "gorm.io/gorm" ) @@ -235,8 +237,16 @@ func buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, st 请以专业体育教师的视角,提供详细的数据分析和教学建议。`, teachingPlanContent, heartRateContent) } +type aiAnalysisResult struct { + Content string + InputTokens int + OutputTokens int + InputSizeBytes int + OutputSizeBytes int +} + // callAIForAnalysis 调用大模型进行分析 -func callAIForAnalysis(prompt string) (string, error) { +func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) { sizeInBytes := len(prompt) sizeInKB := float64(sizeInBytes) / 1024.0 @@ -244,7 +254,7 @@ func callAIForAnalysis(prompt string) (string, error) { log.Printf("=== 发送给 AI 的内容大小: %.2f KB (%d 字节) ===", sizeInKB, sizeInBytes) baseURL, apiKey, model, err := config.GetAIConfig() if err != nil { - return "", err + return nil, err } clientConfig := openai.DefaultConfig(apiKey) @@ -267,14 +277,21 @@ func callAIForAnalysis(prompt string) (string, error) { }, ) if err != nil { - return "", fmt.Errorf("API call failed: %w", err) + return nil, fmt.Errorf("API call failed: %w", err) } if len(resp.Choices) == 0 { - return "", fmt.Errorf("no choices returned from API") + return nil, fmt.Errorf("no choices returned from API") } - return resp.Choices[0].Message.Content, nil + content := resp.Choices[0].Message.Content + return &aiAnalysisResult{ + Content: content, + InputTokens: resp.Usage.PromptTokens, + OutputTokens: resp.Usage.CompletionTokens, + InputSizeBytes: len(prompt), + OutputSizeBytes: len(content), + }, nil } // AnalyzeByAI Gin 控制器方法 @@ -292,6 +309,7 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { stepFiles := form.File["step_data"] analysisType := c.PostForm("analysis_type") teachingPlanSource := c.PostForm("teaching_plan_source") + regionIDStr := c.PostForm("regionid") if analysisType == "" { analysisType = analysisTypeHeartRateOnly } @@ -308,10 +326,12 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { return } + uploadTime := time.Now().UnixMilli() + // 3. 读取文件内容 // 注意:这里我们只取第一个上传的文件 heartRateFileHeader := csvFiles[0] - teachingPlanContent, err := resolveTeachingPlanContent(c, form, teachingPlanSource) + teachingPlanContent, teachingPlanSize, err := resolveTeachingPlanContent(c, form, teachingPlanSource) if err != nil { log.Printf("Error resolving teaching plan: %v", err) if errors.Is(err, gorm.ErrRecordNotFound) { @@ -330,8 +350,10 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { } stepContent := "" + var stepFileSize int64 = 0 if analysisType == analysisTypeHeartRateWithSteps { stepFileHeader := stepFiles[0] + stepFileSize = stepFileHeader.Size stepContent, err = readCSVContent(stepFileHeader) if err != nil { log.Printf("Error reading step file (%s): %v", stepFileHeader.Filename, err) @@ -340,47 +362,76 @@ func (tc *TrainingController) AnalyzeByAI(c *gin.Context) { } } + // 计算文件大小 + originalFileSize := heartRateFileHeader.Size + teachingPlanSize + stepFileSize + compressedContentSize := int64(len(heartRateContent)) + int64(len(teachingPlanContent)) + int64(len(stepContent)) + // 4. 构建 Prompt prompt := buildAnalysisPrompt(teachingPlanContent, heartRateContent, analysisType, stepContent) // 5. 调用 AI 分析 + startTime := time.Now() analysisResult, err := callAIForAnalysis(prompt) if err != nil { log.Printf("Error calling AI for analysis: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("AI analysis failed: %v", err)}) return } - //outputFile := ".md" - //ioutil.WriteFile(outputFile, []byte(analysisResult), 0644) + durationMs := time.Since(startTime).Milliseconds() - // 6. 返回结果 - // 方式一:返回 JSON 结构 + // 6. 保存分析记录 + var regionID *uint32 + if regionIDStr != "" { + if parsed, err := strconv.ParseUint(regionIDStr, 10, 32); err == nil { + id := uint32(parsed) + regionID = &id + } + } + + record := models.AIAnalysisRecord{ + RegionID: regionID, + SourceType: teachingPlanSource, + InputTokens: analysisResult.InputTokens, + OutputTokens: analysisResult.OutputTokens, + InputSizeBytes: analysisResult.InputSizeBytes, + OutputSizeBytes: analysisResult.OutputSizeBytes, + DurationMs: durationMs, + OriginalFileSize: originalFileSize, + CompressedContentSize: compressedContentSize, + UploadTime: uploadTime, + } + if err := config.DB.Create(&record).Error; err != nil { + log.Printf("Failed to save analysis record: %v", err) + } + + // 7. 返回结果 c.JSON(http.StatusOK, gin.H{ "status": "success", - "data": analysisResult, + "data": analysisResult.Content, }) - } -func resolveTeachingPlanContent(c *gin.Context, form *multipart.Form, source string) (string, error) { +func resolveTeachingPlanContent(c *gin.Context, form *multipart.Form, source string) (string, int64, error) { switch strings.ToLower(strings.TrimSpace(source)) { case "upload": docxFiles := form.File["teaching_plan"] if len(docxFiles) == 0 { - return "", fmt.Errorf("Missing required file: teaching_plan (.docx)") + return "", 0, fmt.Errorf("Missing required file: teaching_plan (.docx)") } - return readDocxContent(docxFiles[0]) + content, err := readDocxContent(docxFiles[0]) + return content, docxFiles[0].Size, err case "cloud": lessonPlanID := c.PostForm("lesson_plan_id") if strings.TrimSpace(lessonPlanID) == "" { - return "", fmt.Errorf("missing required field: lesson_plan_id") + return "", 0, fmt.Errorf("missing required field: lesson_plan_id") } var fileRecord models.AppFile if err := config.DB.Where("id = ? AND file_type = ?", lessonPlanID, models.AppFileTypeLessonPlan).First(&fileRecord).Error; err != nil { - return "", err + return "", 0, err } - return readDocxContentFromPath(fileRecord.FilePath) + content, err := readDocxContentFromPath(fileRecord.FilePath) + return content, fileRecord.FileSize, err default: - return "", fmt.Errorf("invalid teaching_plan_source, expected upload or cloud") + return "", 0, fmt.Errorf("invalid teaching_plan_source, expected upload or cloud") } } diff --git a/controllers/statistics.go b/controllers/statistics.go new file mode 100644 index 0000000..e6e9273 --- /dev/null +++ b/controllers/statistics.go @@ -0,0 +1,222 @@ +package controllers + +import ( + "errors" + "hr_receiver/config" + "hr_receiver/models" + "net/http" + "strconv" + "strings" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type StatisticsController struct { + DB *gorm.DB +} + +func NewStatisticsController() *StatisticsController { + return &StatisticsController{DB: config.DB} +} + +// --- 请求参数 --- + +type analysisRecordListParams struct { + PageNum int `form:"pageNum,default=1"` + PageSize int `form:"pageSize,default=10"` + RegionID uint32 `form:"regionId"` + StartTime int64 `form:"startTime"` + EndTime int64 `form:"endTime"` +} + +// --- 查询接口 --- + +func (sc *StatisticsController) ListAIAnalysisRecords(c *gin.Context) { + var params analysisRecordListParams + if err := c.ShouldBindQuery(¶ms); err != nil { + writeError(c, http.StatusBadRequest, err.Error()) + return + } + if params.PageNum < 1 { + params.PageNum = 1 + } + if params.PageSize < 1 || params.PageSize > 100 { + params.PageSize = 10 + } + offset := (params.PageNum - 1) * params.PageSize + + query := sc.DB.Model(&models.AIAnalysisRecord{}) + if params.RegionID > 0 { + query = query.Where("region_id = ?", params.RegionID) + } + if params.StartTime > 0 { + query = query.Where("upload_time >= ?", params.StartTime) + } + if params.EndTime > 0 { + query = query.Where("upload_time <= ?", params.EndTime) + } + + var total int64 + if err := query.Count(&total).Error; err != nil { + writeError(c, http.StatusInternalServerError, "failed to count records") + return + } + + var records []models.AIAnalysisRecord + if err := query.Order("created_at DESC").Offset(offset).Limit(params.PageSize).Find(&records).Error; err != nil { + writeError(c, http.StatusInternalServerError, "failed to query records") + return + } + + writeSuccess(c, http.StatusOK, "query success", gin.H{ + "list": records, + "pagination": gin.H{ + "currentPage": params.PageNum, + "pageSize": params.PageSize, + "totalList": total, + "totalPage": int((total + int64(params.PageSize) - 1) / int64(params.PageSize)), + }, + }) +} + +// --- 删除接口 --- + +func (sc *StatisticsController) DeleteAIAnalysisRecord(c *gin.Context) { + id := strings.TrimSpace(c.Param("id")) + if id == "" { + writeError(c, http.StatusBadRequest, "id is required") + return + } + + var record models.AIAnalysisRecord + if err := sc.DB.First(&record, id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + writeError(c, http.StatusNotFound, "record not found") + return + } + writeError(c, http.StatusInternalServerError, "failed to query record") + return + } + + if err := sc.DB.Delete(&record).Error; err != nil { + writeError(c, http.StatusInternalServerError, "failed to delete record") + return + } + + writeSuccess(c, http.StatusOK, "delete success", nil) +} + +// --- 统计接口 --- + +type regionStatisticsItem struct { + RegionID uint32 `json:"regionId"` + Count int64 `json:"count"` + TotalInputTokens int64 `json:"totalInputTokens"` + TotalOutputTokens int64 `json:"totalOutputTokens"` + TotalInputSizeBytes int64 `json:"totalInputSizeBytes"` + TotalOutputSizeBytes int64 `json:"totalOutputSizeBytes"` + TotalDurationMs int64 `json:"totalDurationMs"` + AvgDurationMs float64 `json:"avgDurationMs"` + TotalOriginalFileSize int64 `json:"totalOriginalFileSize"` + TotalCompressedSize int64 `json:"totalCompressedSize"` +} + +func (sc *StatisticsController) StatisticsByRegion(c *gin.Context) { + regionIDStr := c.Query("regionId") + startTimeStr := c.Query("startTime") + endTimeStr := c.Query("endTime") + + query := sc.DB.Model(&models.AIAnalysisRecord{}) + if regionIDStr != "" { + if regionID, err := strconv.ParseUint(regionIDStr, 10, 32); err == nil { + query = query.Where("region_id = ?", uint32(regionID)) + } + } + if startTimeStr != "" { + if startTime, err := strconv.ParseInt(startTimeStr, 10, 64); err == nil { + query = query.Where("upload_time >= ?", startTime) + } + } + if endTimeStr != "" { + if endTime, err := strconv.ParseInt(endTimeStr, 10, 64); err == nil { + query = query.Where("upload_time <= ?", endTime) + } + } + + type rawStats struct { + RegionID *uint32 + Count int64 + TotalInputTokens int64 + TotalOutputTokens int64 + TotalInputSizeBytes int64 + TotalOutputSizeBytes int64 + TotalDurationMs int64 + TotalOriginalFileSize int64 + TotalCompressedSize int64 + } + + var rawResults []rawStats + err := query.Select(` + region_id, + COUNT(*) as count, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(input_size_bytes), 0) as total_input_size_bytes, + COALESCE(SUM(output_size_bytes), 0) as total_output_size_bytes, + COALESCE(SUM(duration_ms), 0) as total_duration_ms, + COALESCE(SUM(original_file_size), 0) as total_original_file_size, + COALESCE(SUM(compressed_content_size), 0) as total_compressed_size + `).Group("region_id").Scan(&rawResults).Error + + if err != nil { + writeError(c, http.StatusInternalServerError, "failed to query statistics") + return + } + + overall := regionStatisticsItem{} + regions := make(map[string]regionStatisticsItem, len(rawResults)) + + for _, r := range rawResults { + regionID := uint32(0) + if r.RegionID != nil { + regionID = *r.RegionID + } + avgDuration := float64(0) + if r.Count > 0 { + avgDuration = float64(r.TotalDurationMs) / float64(r.Count) + } + item := regionStatisticsItem{ + RegionID: regionID, + Count: r.Count, + TotalInputTokens: r.TotalInputTokens, + TotalOutputTokens: r.TotalOutputTokens, + TotalInputSizeBytes: r.TotalInputSizeBytes, + TotalOutputSizeBytes: r.TotalOutputSizeBytes, + TotalDurationMs: r.TotalDurationMs, + AvgDurationMs: avgDuration, + TotalOriginalFileSize: r.TotalOriginalFileSize, + TotalCompressedSize: r.TotalCompressedSize, + } + + regions[strconv.FormatUint(uint64(regionID), 10)] = item + + overall.Count += r.Count + overall.TotalInputTokens += r.TotalInputTokens + overall.TotalOutputTokens += r.TotalOutputTokens + overall.TotalInputSizeBytes += r.TotalInputSizeBytes + overall.TotalOutputSizeBytes += r.TotalOutputSizeBytes + overall.TotalDurationMs += r.TotalDurationMs + overall.TotalOriginalFileSize += r.TotalOriginalFileSize + overall.TotalCompressedSize += r.TotalCompressedSize + } + + if overall.Count > 0 { + overall.AvgDurationMs = float64(overall.TotalDurationMs) / float64(overall.Count) + } + + writeSuccess(c, http.StatusOK, "query success", gin.H{ + "overall": overall, + "regions": regions, + }) +} diff --git a/main.go b/main.go index e07b91b..36c2e9d 100644 --- a/main.go +++ b/main.go @@ -37,6 +37,7 @@ func main() { &models.MqttGatewayStatusRecord{}, &models.MqttTrainingSessionRecord{}, &models.Gateway{}, + &models.AIAnalysisRecord{}, ) if err := models.BackfillLegacyUserPermissions(config.DB); err != nil { log.Printf("legacy user permission backfill failed: %v", err) diff --git a/models/analyze.go b/models/analyze.go new file mode 100644 index 0000000..3830f60 --- /dev/null +++ b/models/analyze.go @@ -0,0 +1,17 @@ +package models + +import "gorm.io/gorm" + +type AIAnalysisRecord struct { + gorm.Model + RegionID *uint32 `gorm:"index" json:"regionId"` + SourceType string `gorm:"size:32" json:"sourceType"` + InputTokens int `json:"inputTokens"` + OutputTokens int `json:"outputTokens"` + InputSizeBytes int `json:"inputSizeBytes"` + OutputSizeBytes int `json:"outputSizeBytes"` + DurationMs int64 `json:"durationMs"` + OriginalFileSize int64 `json:"originalFileSize"` + CompressedContentSize int64 `json:"compressedContentSize"` + UploadTime int64 `json:"uploadTime"` +} diff --git a/routes/routes.go b/routes/routes.go index 3707171..0d4842d 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -18,6 +18,7 @@ func SetupRouter() *gin.Engine { userAdminController := controllers.NewUserAdminController() gatewayController := controllers.NewGatewayAdminController() systemDebugController := controllers.NewSystemDebugController() + statisticsController := controllers.NewStatisticsController() deviceTokenHandler := func(c *gin.Context) { clientSecret := c.GetHeader("X-API-Key") if clientSecret != middleware.ApiSecret { @@ -85,6 +86,10 @@ func SetupRouter() *gin.Engine { admin.GET("/system-debug/mqtt/status", systemDebugController.MqttStatus) admin.POST("/system-debug/mqtt/start", systemDebugController.StartMqtt) admin.POST("/system-debug/mqtt/stop", systemDebugController.StopMqtt) + + admin.GET("/statistics/ai-analysis-records", statisticsController.ListAIAnalysisRecords) + admin.DELETE("/statistics/ai-analysis-records/:id", statisticsController.DeleteAIAnalysisRecord) + admin.GET("/statistics/ai-analysis", statisticsController.StatisticsByRegion) } v1.GET("/admin/system-debug/mqtt/ws", systemDebugController.MqttWebSocket)