From 23d27b4b6e7e04724fa76859f3d853541d3b6b9e Mon Sep 17 00:00:00 2001 From: laoboli <1293528695@qq.com> Date: Thu, 30 Apr 2026 17:02:26 +0800 Subject: [PATCH] feat: mock training. --- controllers/statistics.go | 246 ++++++++++++++++++++++++++++++++++++++ mockdata/main.go | 230 ++++++++++++++++++++++++++--------- routes/routes.go | 1 + 3 files changed, 423 insertions(+), 54 deletions(-) diff --git a/controllers/statistics.go b/controllers/statistics.go index 8c95fa3..d148d78 100644 --- a/controllers/statistics.go +++ b/controllers/statistics.go @@ -351,6 +351,252 @@ func (sc *StatisticsController) StatisticsByRegion(c *gin.Context) { }) } +type trainingSessionRegionStatisticsItem struct { + RegionID uint32 `json:"regionId"` + KindergartenName string `json:"kindergartenName"` + Count int64 `json:"count"` + StartedCount int64 `json:"startedCount"` + EndedCount int64 `json:"endedCount"` + CompletedCount int64 `json:"completedCount"` + InProgressCount int64 `json:"inProgressCount"` + TotalDurationMs int64 `json:"totalDurationMs"` + AvgDurationMs float64 `json:"avgDurationMs"` + EventTypeCounts map[string]int64 `json:"eventTypeCounts"` + AppNameCounts map[string]int64 `json:"appNameCounts"` + FlavorTypeCounts map[string]int64 `json:"flavorTypeCounts"` + FirstPublishedAt *time.Time `json:"firstPublishedAt"` + LastPublishedAt *time.Time `json:"lastPublishedAt"` +} + +func (sc *StatisticsController) TrainingSessionStatisticsByRegion(c *gin.Context) { + regionIDStr := c.Query("regionId") + flavorType := strings.TrimSpace(c.Query("flavorType")) + startTimeStr := c.Query("startTime") + endTimeStr := c.Query("endTime") + + query := sc.DB.Model(&models.MqttTrainingSessionRecord{}) + if regionIDStr != "" { + if regionID, err := strconv.ParseUint(regionIDStr, 10, 32); err == nil { + query = query.Where("region_id = ?", uint32(regionID)) + } + } + if flavorType != "" { + query = query.Where("flavor_type = ?", flavorType) + } + if startTimeStr != "" { + if startTime, err := strconv.ParseInt(startTimeStr, 10, 64); err == nil { + query = query.Where("published_at >= ?", startTime) + } + } + if endTimeStr != "" { + if endTime, err := strconv.ParseInt(endTimeStr, 10, 64); err == nil { + query = query.Where("published_at <= ?", endTime) + } + } + + type rawTrainingStats struct { + RegionID *uint32 + Count int64 + StartedCount int64 + EndedCount int64 + CompletedCount int64 + InProgressCount int64 + TotalDurationMs int64 + FirstPublishedAt *int64 + LastPublishedAt *int64 + } + + var rawResults []rawTrainingStats + err := query.Select(` + region_id, + COUNT(*) as count, + COALESCE(SUM(CASE WHEN started_at IS NOT NULL THEN 1 ELSE 0 END), 0) as started_count, + COALESCE(SUM(CASE WHEN ended_at IS NOT NULL THEN 1 ELSE 0 END), 0) as ended_count, + COALESCE(SUM(CASE WHEN started_at IS NOT NULL AND ended_at IS NOT NULL THEN 1 ELSE 0 END), 0) as completed_count, + COALESCE(SUM(CASE WHEN started_at IS NOT NULL AND ended_at IS NULL THEN 1 ELSE 0 END), 0) as in_progress_count, + COALESCE(SUM(CASE WHEN started_at IS NOT NULL AND ended_at IS NOT NULL AND ended_at >= started_at THEN ended_at - started_at ELSE 0 END), 0) as total_duration_ms, + MIN(published_at) as first_published_at, + MAX(published_at) as last_published_at + `).Group("region_id").Scan(&rawResults).Error + if err != nil { + writeError(c, http.StatusInternalServerError, "failed to query training session statistics") + return + } + + type trainingEventTypeCount struct { + RegionID *uint32 + EventType string + Count int64 + } + var eventTypeResults []trainingEventTypeCount + if err := query.Select("region_id, event_type, COUNT(*) as count").Group("region_id, event_type").Scan(&eventTypeResults).Error; err != nil { + writeError(c, http.StatusInternalServerError, "failed to query training session event type statistics") + return + } + + type trainingAppNameCount struct { + RegionID *uint32 + AppName string + Count int64 + } + var appNameResults []trainingAppNameCount + if err := query.Select("region_id, app_name, COUNT(*) as count").Group("region_id, app_name").Scan(&appNameResults).Error; err != nil { + writeError(c, http.StatusInternalServerError, "failed to query training session app name statistics") + return + } + + type trainingFlavorTypeCount struct { + RegionID *uint32 + FlavorType string + Count int64 + } + var flavorTypeResults []trainingFlavorTypeCount + if err := query.Select("region_id, flavor_type, COUNT(*) as count").Group("region_id, flavor_type").Scan(&flavorTypeResults).Error; err != nil { + writeError(c, http.StatusInternalServerError, "failed to query training session flavor type statistics") + return + } + + eventTypeMap := make(map[uint32]map[string]int64) + for _, r := range eventTypeResults { + regionID := uint32(0) + if r.RegionID != nil { + regionID = *r.RegionID + } + if eventTypeMap[regionID] == nil { + eventTypeMap[regionID] = make(map[string]int64) + } + eventTypeMap[regionID][r.EventType] = r.Count + } + + appNameMap := make(map[uint32]map[string]int64) + for _, r := range appNameResults { + regionID := uint32(0) + if r.RegionID != nil { + regionID = *r.RegionID + } + if appNameMap[regionID] == nil { + appNameMap[regionID] = make(map[string]int64) + } + appNameMap[regionID][r.AppName] = r.Count + } + + flavorTypeMap := make(map[uint32]map[string]int64) + for _, r := range flavorTypeResults { + regionID := uint32(0) + if r.RegionID != nil { + regionID = *r.RegionID + } + if flavorTypeMap[regionID] == nil { + flavorTypeMap[regionID] = make(map[string]int64) + } + flavorTypeMap[regionID][r.FlavorType] = r.Count + } + + regionIDs := make([]uint32, 0, len(rawResults)) + for _, r := range rawResults { + if r.RegionID != nil && *r.RegionID > 0 { + regionIDs = append(regionIDs, *r.RegionID) + } + } + kindergartenMap := make(map[uint32]string) + if len(regionIDs) > 0 { + var kindergartens []models.Kindergarten + if err := sc.DB.Where("region_id IN ?", regionIDs).Find(&kindergartens).Error; err == nil { + for _, k := range kindergartens { + kindergartenMap[k.RegionID] = k.Name + } + } + } + + overall := trainingSessionRegionStatisticsItem{ + EventTypeCounts: make(map[string]int64), + AppNameCounts: make(map[string]int64), + FlavorTypeCounts: make(map[string]int64), + } + regions := make(map[string]trainingSessionRegionStatisticsItem, len(rawResults)) + + for _, r := range rawResults { + regionID := uint32(0) + if r.RegionID != nil { + regionID = *r.RegionID + } + avgDuration := float64(0) + if r.CompletedCount > 0 { + avgDuration = float64(r.TotalDurationMs) / float64(r.CompletedCount) + } + kgName := "" + if regionID > 0 { + kgName = kindergartenMap[regionID] + } + + var firstPublishedAt, lastPublishedAt *time.Time + if r.FirstPublishedAt != nil { + t := time.UnixMilli(*r.FirstPublishedAt) + firstPublishedAt = &t + } + if r.LastPublishedAt != nil { + t := time.UnixMilli(*r.LastPublishedAt) + lastPublishedAt = &t + } + + item := trainingSessionRegionStatisticsItem{ + RegionID: regionID, + KindergartenName: kgName, + Count: r.Count, + StartedCount: r.StartedCount, + EndedCount: r.EndedCount, + CompletedCount: r.CompletedCount, + InProgressCount: r.InProgressCount, + TotalDurationMs: r.TotalDurationMs, + AvgDurationMs: avgDuration, + EventTypeCounts: eventTypeMap[regionID], + AppNameCounts: appNameMap[regionID], + FlavorTypeCounts: flavorTypeMap[regionID], + FirstPublishedAt: firstPublishedAt, + LastPublishedAt: lastPublishedAt, + } + + regions[strconv.FormatUint(uint64(regionID), 10)] = item + + overall.Count += r.Count + overall.StartedCount += r.StartedCount + overall.EndedCount += r.EndedCount + overall.CompletedCount += r.CompletedCount + overall.InProgressCount += r.InProgressCount + overall.TotalDurationMs += r.TotalDurationMs + + if firstPublishedAt != nil { + if overall.FirstPublishedAt == nil || firstPublishedAt.Before(*overall.FirstPublishedAt) { + overall.FirstPublishedAt = firstPublishedAt + } + } + if lastPublishedAt != nil { + if overall.LastPublishedAt == nil || lastPublishedAt.After(*overall.LastPublishedAt) { + overall.LastPublishedAt = lastPublishedAt + } + } + } + + for _, r := range eventTypeResults { + overall.EventTypeCounts[r.EventType] += r.Count + } + for _, r := range appNameResults { + overall.AppNameCounts[r.AppName] += r.Count + } + for _, r := range flavorTypeResults { + overall.FlavorTypeCounts[r.FlavorType] += r.Count + } + + if overall.CompletedCount > 0 { + overall.AvgDurationMs = float64(overall.TotalDurationMs) / float64(overall.CompletedCount) + } + + writeSuccess(c, http.StatusOK, "query success", gin.H{ + "overall": overall, + "regions": regions, + }) +} + func (sc *StatisticsController) TimelineStatistics(c *gin.Context) { regionIDStr := c.Query("regionId") startTimeStr := c.Query("startTime") diff --git a/mockdata/main.go b/mockdata/main.go index 7bcbdeb..447187a 100644 --- a/mockdata/main.go +++ b/mockdata/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "fmt" "math/rand" "time" @@ -9,88 +10,92 @@ import ( "hr_receiver/models" ) +const ( + aiAnalysisRecordCount = 100 + trainingSessionRecordCount = 80 + insertBatchSize = 50 +) + +var ( + regionIDs = []uint32{1, 3} + sourceTypes = []string{"upload", "cloud"} + appNames = []string{"HeartRate Teacher", "HeartRate Console", "HeartRate Admin"} +) + func main() { + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + config.InitConfig() config.ConnectDB() - // 生成100条测试数据 - count := 100 - records := make([]models.AIAnalysisRecord, 0, count) + //aiRecords := generateAIAnalysisRecords(rng, aiAnalysisRecordCount) + //if err := config.DB.CreateInBatches(aiRecords, insertBatchSize).Error; err != nil { + // panic("failed to insert AI analysis mock data: " + err.Error()) + //} - for i := 0; i < count; i++ { - records = append(records, generateRecord()) + trainingRecords := generateTrainingSessionRecords(rng, trainingSessionRecordCount) + if err := config.DB.CreateInBatches(trainingRecords, insertBatchSize).Error; err != nil { + panic("failed to insert training session mock data: " + err.Error()) } - if err := config.DB.CreateInBatches(records, 50).Error; err != nil { - panic("failed to insert mock data: " + err.Error()) - } - - fmt.Printf("成功插入 %d 条 AI 分析记录\n", count) + //fmt.Printf("成功插入 %d 条 AI 分析记录\n", len(aiRecords)) + fmt.Printf("成功插入 %d 条训练会话记录\n", len(trainingRecords)) } -func generateRecord() models.AIAnalysisRecord { - // regionID 为 1 或 3 - regionID := uint32(1) - if rand.Intn(2) == 1 { - regionID = 3 +func generateAIAnalysisRecords(rng *rand.Rand, count int) []models.AIAnalysisRecord { + records := make([]models.AIAnalysisRecord, 0, count) + for i := 0; i < count; i++ { + records = append(records, generateAIAnalysisRecord(rng)) } + return records +} - // sourceType: upload 或 cloud - sourceType := "upload" - if rand.Intn(2) == 1 { - sourceType = "cloud" - } +func generateAIAnalysisRecord(rng *rand.Rand) models.AIAnalysisRecord { + regionID := pickRegionID(rng) + sourceType := pickOne(rng, sourceTypes) - // docx 教案原始文件大小: 50KB ~ 500KB - docxSize := int64(rand.Intn(451*1024) + 50*1024) + docxSize := int64(rng.Intn(451*1024) + 50*1024) + csvSize := int64(rng.Intn(20*1024) + 70*1024) - // 心率 csv 原始文件大小: 约 80KB (70KB ~ 90KB) - csvSize := int64(rand.Intn(20*1024) + 70*1024) - - // 步数 csv 原始文件大小: 约 20KB ~ 40KB (heart_rate_with_steps 时才有) - var stepCsvSize int64 - analysisType := analysisType() + analysisType := randomAnalysisType(rng) + var stepCSVSize int64 if analysisType == "heart_rate_with_steps" { - stepCsvSize = int64(rand.Intn(20*1024) + 20*1024) + stepCSVSize = int64(rng.Intn(20*1024) + 20*1024) } - originalFileSize := docxSize + csvSize + stepCsvSize + originalFileSize := docxSize + csvSize + stepCSVSize - // 压缩后内容大小: csv 每4行保留1行,大约压缩为 25% + 表头;docx 提取文本后大约 30%~60% - compressedDocx := int64(float64(docxSize) * (0.3 + rand.Float64()*0.3)) - compressedCsv := int64(float64(csvSize) * (0.22 + rand.Float64()*0.08)) // ~22%-30% - var compressedStepCsv int64 - if stepCsvSize > 0 { - compressedStepCsv = int64(float64(stepCsvSize) * (0.22 + rand.Float64()*0.08)) + compressedDocx := int64(float64(docxSize) * (0.3 + rng.Float64()*0.3)) + compressedCSV := int64(float64(csvSize) * (0.22 + rng.Float64()*0.08)) + var compressedStepCSV int64 + if stepCSVSize > 0 { + compressedStepCSV = int64(float64(stepCSVSize) * (0.22 + rng.Float64()*0.08)) } - compressedContentSize := compressedDocx + compressedCsv + compressedStepCsv + compressedContentSize := compressedDocx + compressedCSV + compressedStepCSV - // prompt 大小 = 压缩后内容 + 提示词模板 (~1.5KB) - promptTemplateSize := 1500 + rand.Intn(500) + promptTemplateSize := 1500 + rng.Intn(500) inputSizeBytes := int(compressedContentSize) + promptTemplateSize + outputSizeBytes := rng.Intn(22*1024) + 3*1024 - // AI 输出大小: 3KB ~ 25KB (分析报告) - outputSizeBytes := rand.Intn(22*1024) + 3*1024 + inputTokens := inputSizeBytes / (3 + rng.Intn(2)) + outputTokens := outputSizeBytes / (3 + rng.Intn(2)) - // token 估算: 中文混合场景,平均约 3.5 字节/token - inputTokens := inputSizeBytes / (3 + rand.Intn(2)) - outputTokens := outputSizeBytes / (3 + rand.Intn(2)) - - // 分析时长: 主要和输出 token 数量相关,1分钟以内 - // 基础延迟 500ms + 每token约 15~40ms - tokenLatency := int64(15 + rand.Intn(26)) + tokenLatency := int64(15 + rng.Intn(26)) durationMs := 500 + int64(outputTokens)*tokenLatency if durationMs > 60000 { - durationMs = 60000 - int64(rand.Intn(5000)) + durationMs = 60000 - int64(rng.Intn(5000)) } - // 上传时间: 最近 90 天内随机 - uploadTime := time.Now().Add(-time.Duration(rand.Intn(90*24)) * time.Hour).Add(-time.Duration(rand.Intn(60)) * time.Minute).UnixMilli() + uploadTime := randomRecentMillis(rng, 90) + totalCost, costJSON := buildMockCost(inputTokens, outputTokens) return models.AIAnalysisRecord{ RegionID: ®ionID, SourceType: sourceType, AnalysisType: analysisType, + AnalysisResult: "mock analysis result", + CostJSON: costJSON, + TotalCost: totalCost, InputTokens: inputTokens, OutputTokens: outputTokens, InputSizeBytes: inputSizeBytes, @@ -102,10 +107,127 @@ func generateRecord() models.AIAnalysisRecord { } } -func analysisType() string { - // 约 30% 的带步数分析 - if rand.Intn(100) < 30 { +func generateTrainingSessionRecords(rng *rand.Rand, count int) []models.MqttTrainingSessionRecord { + records := make([]models.MqttTrainingSessionRecord, 0, count) + sessionCount := count / 2 + if sessionCount == 0 { + sessionCount = 1 + } + + for i := 0; i < sessionCount; i++ { + sessionRecords := generateTrainingSessionPair(rng, i) + records = append(records, sessionRecords...) + if len(records) >= count { + break + } + } + + if len(records) > count { + records = records[:count] + } + return records +} + +func generateTrainingSessionPair(rng *rand.Rand, index int) []models.MqttTrainingSessionRecord { + regionID := pickRegionID(rng) + appName := pickOne(rng, appNames) + testID := fmt.Sprintf("mock-test-%04d", index+1) + startedAt := randomRecentMillis(rng, 60) + publishedAt := startedAt + int64(rng.Intn(30_000)) + + startRecord := models.MqttTrainingSessionRecord{ + Identifier: fmt.Sprintf("mock-heartrate-%d-%s", regionID, testID), + Topic: fmt.Sprintf("/whgw/v2/region/test/%d", regionID), + TestID: testID, + EventType: "start_test", + RegionID: regionID, + FlavorType: "heartrate", + RawFlavor: "heartrate", + AppName: appName, + StartedAt: int64Ptr(startedAt), + PublishedAt: publishedAt, + ReceivedAt: publishedAt + int64(rng.Intn(3000)), + RawPayload: buildTrainingPayloadJSON(testID, regionID, appName, "start_test", publishedAt), + } + + records := []models.MqttTrainingSessionRecord{startRecord} + + // 保留少量“开始未结束”的样本,用于测试进行中统计。 + if rng.Intn(100) < 15 { + return records + } + + durationMs := int64((5+rng.Intn(41))*60*1000 + rng.Intn(60_000)) + endedAt := startedAt + durationMs + stopPublishedAt := endedAt + int64(rng.Intn(20_000)) + stopRecord := models.MqttTrainingSessionRecord{ + Identifier: fmt.Sprintf("mock-heartrate-%d-%s-stop", regionID, testID), + Topic: fmt.Sprintf("/whgw/v2/region/test/%d", regionID), + TestID: testID, + EventType: "stop_test", + RegionID: regionID, + FlavorType: "heartrate", + RawFlavor: "heartrate", + AppName: appName, + StartedAt: int64Ptr(startedAt), + EndedAt: int64Ptr(endedAt), + PublishedAt: stopPublishedAt, + ReceivedAt: stopPublishedAt + int64(rng.Intn(3000)), + RawPayload: buildTrainingPayloadJSON(testID, regionID, appName, "stop_test", stopPublishedAt), + } + + return append(records, stopRecord) +} + +func buildMockCost(inputTokens, outputTokens int) (float64, string) { + inputCost := float64(inputTokens) * 0.000002 + outputCost := float64(outputTokens) * 0.000008 + totalCost := inputCost + outputCost + + payload := map[string]float64{ + "inputCost": inputCost, + "outputCost": outputCost, + } + data, _ := json.Marshal(payload) + return totalCost, string(data) +} + +func buildTrainingPayloadJSON(testID string, regionID uint32, appName, eventType string, publishedAt int64) string { + payload := map[string]string{ + "appName": appName, + "eventType": eventType, + "flavor": "heartrate", + "regionId": fmt.Sprintf("%d", regionID), + "testId": testID, + "timestamp": time.UnixMilli(publishedAt).UTC().Format("2006-01-02T15:04:05.000Z"), + "type": "mqtt_test", + } + data, _ := json.Marshal(payload) + return string(data) +} + +func randomAnalysisType(rng *rand.Rand) string { + if rng.Intn(100) < 30 { return "heart_rate_with_steps" } return "heart_rate_only" } + +func randomRecentMillis(rng *rand.Rand, days int) int64 { + return time.Now(). + Add(-time.Duration(rng.Intn(days*24)) * time.Hour). + Add(-time.Duration(rng.Intn(60)) * time.Minute). + UnixMilli() +} + +func pickRegionID(rng *rand.Rand) uint32 { + return regionIDs[rng.Intn(len(regionIDs))] +} + +func pickOne(rng *rand.Rand, values []string) string { + return values[rng.Intn(len(values))] +} + +func int64Ptr(v int64) *int64 { + return &v +} diff --git a/routes/routes.go b/routes/routes.go index dd3ecbf..55ca237 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -91,6 +91,7 @@ func SetupRouter() *gin.Engine { admin.DELETE("/statistics/ai-analysis-records/:id", statisticsController.DeleteAIAnalysisRecord) admin.GET("/statistics/ai-analysis", statisticsController.StatisticsByRegion) admin.GET("/statistics/ai-analysis-timeline", statisticsController.TimelineStatistics) + admin.GET("/statistics/mqtt-training-sessions", statisticsController.TrainingSessionStatisticsByRegion) } v1.GET("/admin/system-debug/mqtt/ws", systemDebugController.MqttWebSocket)