feat: mock training.
This commit is contained in:
+176
-54
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
@@ -9,88 +10,92 @@ import (
|
||||
"hr_receiver/models"
|
||||
)
|
||||
|
||||
// Sizing knobs for the mock data seeder.
const (
	// aiAnalysisRecordCount is the number of AI analysis records to generate.
	aiAnalysisRecordCount = 100
	// trainingSessionRecordCount is the number of training session records
	// (start and stop events combined) to generate.
	trainingSessionRecordCount = 80
	// insertBatchSize is the batch size passed to CreateInBatches.
	insertBatchSize = 50
)
|
||||
|
||||
// Value pools the generators draw from at random.
var (
	// regionIDs holds the region identifiers mock records are assigned to.
	regionIDs = []uint32{1, 3}
	// sourceTypes holds the possible SourceType values of an AI analysis record.
	sourceTypes = []string{"upload", "cloud"}
	// appNames holds the application names used for training session records.
	appNames = []string{"HeartRate Teacher", "HeartRate Console", "HeartRate Admin"}
)
|
||||
|
||||
func main() {
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
config.InitConfig()
|
||||
config.ConnectDB()
|
||||
|
||||
// 生成100条测试数据
|
||||
count := 100
|
||||
records := make([]models.AIAnalysisRecord, 0, count)
|
||||
//aiRecords := generateAIAnalysisRecords(rng, aiAnalysisRecordCount)
|
||||
//if err := config.DB.CreateInBatches(aiRecords, insertBatchSize).Error; err != nil {
|
||||
// panic("failed to insert AI analysis mock data: " + err.Error())
|
||||
//}
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
records = append(records, generateRecord())
|
||||
trainingRecords := generateTrainingSessionRecords(rng, trainingSessionRecordCount)
|
||||
if err := config.DB.CreateInBatches(trainingRecords, insertBatchSize).Error; err != nil {
|
||||
panic("failed to insert training session mock data: " + err.Error())
|
||||
}
|
||||
|
||||
if err := config.DB.CreateInBatches(records, 50).Error; err != nil {
|
||||
panic("failed to insert mock data: " + err.Error())
|
||||
}
|
||||
|
||||
fmt.Printf("成功插入 %d 条 AI 分析记录\n", count)
|
||||
//fmt.Printf("成功插入 %d 条 AI 分析记录\n", len(aiRecords))
|
||||
fmt.Printf("成功插入 %d 条训练会话记录\n", len(trainingRecords))
|
||||
}
|
||||
|
||||
func generateRecord() models.AIAnalysisRecord {
|
||||
// regionID 为 1 或 3
|
||||
regionID := uint32(1)
|
||||
if rand.Intn(2) == 1 {
|
||||
regionID = 3
|
||||
func generateAIAnalysisRecords(rng *rand.Rand, count int) []models.AIAnalysisRecord {
|
||||
records := make([]models.AIAnalysisRecord, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
records = append(records, generateAIAnalysisRecord(rng))
|
||||
}
|
||||
return records
|
||||
}
|
||||
|
||||
// sourceType: upload 或 cloud
|
||||
sourceType := "upload"
|
||||
if rand.Intn(2) == 1 {
|
||||
sourceType = "cloud"
|
||||
}
|
||||
func generateAIAnalysisRecord(rng *rand.Rand) models.AIAnalysisRecord {
|
||||
regionID := pickRegionID(rng)
|
||||
sourceType := pickOne(rng, sourceTypes)
|
||||
|
||||
// docx 教案原始文件大小: 50KB ~ 500KB
|
||||
docxSize := int64(rand.Intn(451*1024) + 50*1024)
|
||||
docxSize := int64(rng.Intn(451*1024) + 50*1024)
|
||||
csvSize := int64(rng.Intn(20*1024) + 70*1024)
|
||||
|
||||
// 心率 csv 原始文件大小: 约 80KB (70KB ~ 90KB)
|
||||
csvSize := int64(rand.Intn(20*1024) + 70*1024)
|
||||
|
||||
// 步数 csv 原始文件大小: 约 20KB ~ 40KB (heart_rate_with_steps 时才有)
|
||||
var stepCsvSize int64
|
||||
analysisType := analysisType()
|
||||
analysisType := randomAnalysisType(rng)
|
||||
var stepCSVSize int64
|
||||
if analysisType == "heart_rate_with_steps" {
|
||||
stepCsvSize = int64(rand.Intn(20*1024) + 20*1024)
|
||||
stepCSVSize = int64(rng.Intn(20*1024) + 20*1024)
|
||||
}
|
||||
|
||||
originalFileSize := docxSize + csvSize + stepCsvSize
|
||||
originalFileSize := docxSize + csvSize + stepCSVSize
|
||||
|
||||
// 压缩后内容大小: csv 每4行保留1行,大约压缩为 25% + 表头;docx 提取文本后大约 30%~60%
|
||||
compressedDocx := int64(float64(docxSize) * (0.3 + rand.Float64()*0.3))
|
||||
compressedCsv := int64(float64(csvSize) * (0.22 + rand.Float64()*0.08)) // ~22%-30%
|
||||
var compressedStepCsv int64
|
||||
if stepCsvSize > 0 {
|
||||
compressedStepCsv = int64(float64(stepCsvSize) * (0.22 + rand.Float64()*0.08))
|
||||
compressedDocx := int64(float64(docxSize) * (0.3 + rng.Float64()*0.3))
|
||||
compressedCSV := int64(float64(csvSize) * (0.22 + rng.Float64()*0.08))
|
||||
var compressedStepCSV int64
|
||||
if stepCSVSize > 0 {
|
||||
compressedStepCSV = int64(float64(stepCSVSize) * (0.22 + rng.Float64()*0.08))
|
||||
}
|
||||
compressedContentSize := compressedDocx + compressedCsv + compressedStepCsv
|
||||
compressedContentSize := compressedDocx + compressedCSV + compressedStepCSV
|
||||
|
||||
// prompt 大小 = 压缩后内容 + 提示词模板 (~1.5KB)
|
||||
promptTemplateSize := 1500 + rand.Intn(500)
|
||||
promptTemplateSize := 1500 + rng.Intn(500)
|
||||
inputSizeBytes := int(compressedContentSize) + promptTemplateSize
|
||||
outputSizeBytes := rng.Intn(22*1024) + 3*1024
|
||||
|
||||
// AI 输出大小: 3KB ~ 25KB (分析报告)
|
||||
outputSizeBytes := rand.Intn(22*1024) + 3*1024
|
||||
inputTokens := inputSizeBytes / (3 + rng.Intn(2))
|
||||
outputTokens := outputSizeBytes / (3 + rng.Intn(2))
|
||||
|
||||
// token 估算: 中文混合场景,平均约 3.5 字节/token
|
||||
inputTokens := inputSizeBytes / (3 + rand.Intn(2))
|
||||
outputTokens := outputSizeBytes / (3 + rand.Intn(2))
|
||||
|
||||
// 分析时长: 主要和输出 token 数量相关,1分钟以内
|
||||
// 基础延迟 500ms + 每token约 15~40ms
|
||||
tokenLatency := int64(15 + rand.Intn(26))
|
||||
tokenLatency := int64(15 + rng.Intn(26))
|
||||
durationMs := 500 + int64(outputTokens)*tokenLatency
|
||||
if durationMs > 60000 {
|
||||
durationMs = 60000 - int64(rand.Intn(5000))
|
||||
durationMs = 60000 - int64(rng.Intn(5000))
|
||||
}
|
||||
|
||||
// 上传时间: 最近 90 天内随机
|
||||
uploadTime := time.Now().Add(-time.Duration(rand.Intn(90*24)) * time.Hour).Add(-time.Duration(rand.Intn(60)) * time.Minute).UnixMilli()
|
||||
uploadTime := randomRecentMillis(rng, 90)
|
||||
totalCost, costJSON := buildMockCost(inputTokens, outputTokens)
|
||||
|
||||
return models.AIAnalysisRecord{
|
||||
RegionID: &regionID,
|
||||
SourceType: sourceType,
|
||||
AnalysisType: analysisType,
|
||||
AnalysisResult: "mock analysis result",
|
||||
CostJSON: costJSON,
|
||||
TotalCost: totalCost,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
InputSizeBytes: inputSizeBytes,
|
||||
@@ -102,10 +107,127 @@ func generateRecord() models.AIAnalysisRecord {
|
||||
}
|
||||
}
|
||||
|
||||
func analysisType() string {
|
||||
// 约 30% 的带步数分析
|
||||
if rand.Intn(100) < 30 {
|
||||
func generateTrainingSessionRecords(rng *rand.Rand, count int) []models.MqttTrainingSessionRecord {
|
||||
records := make([]models.MqttTrainingSessionRecord, 0, count)
|
||||
sessionCount := count / 2
|
||||
if sessionCount == 0 {
|
||||
sessionCount = 1
|
||||
}
|
||||
|
||||
for i := 0; i < sessionCount; i++ {
|
||||
sessionRecords := generateTrainingSessionPair(rng, i)
|
||||
records = append(records, sessionRecords...)
|
||||
if len(records) >= count {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(records) > count {
|
||||
records = records[:count]
|
||||
}
|
||||
return records
|
||||
}
|
||||
|
||||
func generateTrainingSessionPair(rng *rand.Rand, index int) []models.MqttTrainingSessionRecord {
|
||||
regionID := pickRegionID(rng)
|
||||
appName := pickOne(rng, appNames)
|
||||
testID := fmt.Sprintf("mock-test-%04d", index+1)
|
||||
startedAt := randomRecentMillis(rng, 60)
|
||||
publishedAt := startedAt + int64(rng.Intn(30_000))
|
||||
|
||||
startRecord := models.MqttTrainingSessionRecord{
|
||||
Identifier: fmt.Sprintf("mock-heartrate-%d-%s", regionID, testID),
|
||||
Topic: fmt.Sprintf("/whgw/v2/region/test/%d", regionID),
|
||||
TestID: testID,
|
||||
EventType: "start_test",
|
||||
RegionID: regionID,
|
||||
FlavorType: "heartrate",
|
||||
RawFlavor: "heartrate",
|
||||
AppName: appName,
|
||||
StartedAt: int64Ptr(startedAt),
|
||||
PublishedAt: publishedAt,
|
||||
ReceivedAt: publishedAt + int64(rng.Intn(3000)),
|
||||
RawPayload: buildTrainingPayloadJSON(testID, regionID, appName, "start_test", publishedAt),
|
||||
}
|
||||
|
||||
records := []models.MqttTrainingSessionRecord{startRecord}
|
||||
|
||||
// 保留少量“开始未结束”的样本,用于测试进行中统计。
|
||||
if rng.Intn(100) < 15 {
|
||||
return records
|
||||
}
|
||||
|
||||
durationMs := int64((5+rng.Intn(41))*60*1000 + rng.Intn(60_000))
|
||||
endedAt := startedAt + durationMs
|
||||
stopPublishedAt := endedAt + int64(rng.Intn(20_000))
|
||||
stopRecord := models.MqttTrainingSessionRecord{
|
||||
Identifier: fmt.Sprintf("mock-heartrate-%d-%s-stop", regionID, testID),
|
||||
Topic: fmt.Sprintf("/whgw/v2/region/test/%d", regionID),
|
||||
TestID: testID,
|
||||
EventType: "stop_test",
|
||||
RegionID: regionID,
|
||||
FlavorType: "heartrate",
|
||||
RawFlavor: "heartrate",
|
||||
AppName: appName,
|
||||
StartedAt: int64Ptr(startedAt),
|
||||
EndedAt: int64Ptr(endedAt),
|
||||
PublishedAt: stopPublishedAt,
|
||||
ReceivedAt: stopPublishedAt + int64(rng.Intn(3000)),
|
||||
RawPayload: buildTrainingPayloadJSON(testID, regionID, appName, "stop_test", stopPublishedAt),
|
||||
}
|
||||
|
||||
return append(records, stopRecord)
|
||||
}
|
||||
|
||||
// buildMockCost derives a mock cost from token counts.
//
// It returns the total cost and a JSON object with the "inputCost" and
// "outputCost" components. If the breakdown cannot be encoded, the total is
// still returned together with an empty JSON object ("{}").
func buildMockCost(inputTokens, outputTokens int) (float64, string) {
	const (
		inputCostPerToken  = 0.000002
		outputCostPerToken = 0.000008
	)

	inputCost := float64(inputTokens) * inputCostPerToken
	outputCost := float64(outputTokens) * outputCostPerToken
	totalCost := inputCost + outputCost

	payload := map[string]float64{
		"inputCost":  inputCost,
		"outputCost": outputCost,
	}
	data, err := json.Marshal(payload)
	if err != nil {
		// Only non-finite floats can make this fail; keep the stored value
		// valid JSON rather than an empty string.
		return totalCost, "{}"
	}
	return totalCost, string(data)
}
|
||||
|
||||
// buildTrainingPayloadJSON renders the mock raw payload of a training
// session event as a JSON object.
//
// All values are encoded as strings, including regionId. publishedAt is a
// Unix timestamp in milliseconds and is written to the "timestamp" field as a
// UTC time with millisecond precision. If encoding fails, an empty JSON
// object ("{}") is returned.
func buildTrainingPayloadJSON(testID string, regionID uint32, appName, eventType string, publishedAt int64) string {
	payload := map[string]string{
		"appName":   appName,
		"eventType": eventType,
		"flavor":    "heartrate",
		"regionId":  fmt.Sprintf("%d", regionID),
		"testId":    testID,
		"timestamp": time.UnixMilli(publishedAt).UTC().Format("2006-01-02T15:04:05.000Z"),
		"type":      "mqtt_test",
	}

	data, err := json.Marshal(payload)
	if err != nil {
		return "{}"
	}
	return string(data)
}
|
||||
|
||||
// randomAnalysisType picks the analysis type of a mock record:
// "heart_rate_with_steps" with a probability of 30%, "heart_rate_only"
// otherwise.
func randomAnalysisType(rng *rand.Rand) string {
	if rng.Intn(100) < 30 {
		return "heart_rate_with_steps"
	}
	return "heart_rate_only"
}
|
||||
|
||||
// randomRecentMillis returns a random Unix timestamp in milliseconds that
// lies within the last days days, relative to the current time.
//
// The offset is a whole number of hours below days*24 plus a whole number of
// minutes below 60. For a non-positive days the hour offset is zero, so the
// result falls within the last hour; rng.Intn would otherwise panic on a
// non-positive bound.
func randomRecentMillis(rng *rand.Rand, days int) int64 {
	var hoursBack int
	if days > 0 {
		hoursBack = rng.Intn(days * 24)
	}
	minutesBack := rng.Intn(60)

	return time.Now().
		Add(-time.Duration(hoursBack) * time.Hour).
		Add(-time.Duration(minutesBack) * time.Minute).
		UnixMilli()
}
|
||||
|
||||
func pickRegionID(rng *rand.Rand) uint32 {
|
||||
return regionIDs[rng.Intn(len(regionIDs))]
|
||||
}
|
||||
|
||||
// pickOne returns a uniformly chosen element of values.
//
// An empty or nil values yields the empty string, because rng.Intn panics on
// a zero bound.
func pickOne(rng *rand.Rand, values []string) string {
	if len(values) == 0 {
		return ""
	}
	return values[rng.Intn(len(values))]
}
|
||||
|
||||
// int64Ptr returns a pointer to a fresh copy of v, for populating optional
// *int64 fields from a value.
func int64Ptr(v int64) *int64 {
	return &v
}
|
||||
|
||||
Reference in New Issue
Block a user