package controllers import ( "gonum.org/v1/gonum/floats" "gonum.org/v1/gonum/stat" "gonum.org/v1/gonum/stat/distuv" ) import ( "errors" "github.com/gin-gonic/gin" "gorm.io/gorm" "gorm.io/gorm/clause" "hr_receiver/config" "hr_receiver/models" "math" "net/http" ) var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"} // 替换为你的具体值 func contains(s string) bool { for _, item := range analyzeRunTypes { if item == s { return true } } return false } type TrainingController struct { DB *gorm.DB } func NewTrainingController() *TrainingController { return &TrainingController{DB: config.DB} } // 接收训练记录 func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) { var record models.TrainRecord // 绑定并验证JSON数据 if err := c.ShouldBindJSON(&record); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // 使用事务保存数据[4](@ref) err := tc.DB.Transaction(func(tx *gorm.DB) error { // 保存主记录 if err := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "train_id"}}, // 指定冲突的列 DoUpdates: clause.Assignments(map[string]interface{}{ "max_heart_rate": record.MaxHeartRate, "start_time": record.StartTime, "end_time": record.EndTime, "duration": record.Duration, "people_num": record.PeopleNum, "name": record.Name, "evaluation": record.Evaluation, }), }).Omit("HeartRates", "belts").Create(&record).Error; err != nil { return err } // 保存关联的心率数据 for i := range record.HeartRates { if err := tx.Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列 DoUpdates: clause.Assignments(map[string]interface{}{"value": record.HeartRates[i].Value, "time": record.HeartRates[i].Time}), }, ).Create(&record.HeartRates[i]).Error; err != nil { return err } } if contains(record.RunType) { err := tc.heartRateAnalyze(tx, record) if err != nil { return err } } return nil }) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusCreated, gin.H{ "message": "数据保存成功", "id": record.TrainId, }) } // analysis_response.go type AnalysisResponse struct { Status string `json:"status"` // 状态码 Message string `json:"message"` // 附加信息 Data struct { Mean float64 `json:"mean"` // 均值 StdDev float64 `json:"stdDev"` // 标准差 Histogram []HistoBin `json:"histogram"` // 直方图数据 Curve []CurvePoint `json:"curve"` // 正态曲线数据 } `json:"data"` } type HistoBin struct { BinStart float64 `json:"binStart"` // 区间起始值 BinEnd float64 `json:"binEnd"` // 区间结束值 Count int `json:"count"` // 该区间计数 } type CurvePoint struct { X float64 `json:"x"` // X坐标 Y float64 `json:"y"` // Y坐标 } // analysis_handler.go func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) { // 获取数据库连接(根据实际项目配置调整) // 1. 获取历史数据 aValues, err := collectCurveParams(tc.DB) if err != nil { c.JSON(500, gin.H{ "status": "error", "message": "数据查询失败: " + err.Error(), }) return } // 2. 检查数据有效性 if len(aValues) < 10 { // 至少需要10个样本 c.JSON(400, gin.H{ "status": "fail", "message": "数据量不足,至少需要10个样本", }) return } // 3. 计算统计量 mean, stddev := calculateStats(aValues) // 4. 生成直方图数据 histogram := calculateHistogram(aValues, 20) // 20个分箱 // 5. 生成正态曲线 x, y := generateNormalCurve(mean, stddev, 100) // 6. 构造响应 response := AnalysisResponse{ Status: "success", Message: "分析完成", Data: struct { Mean float64 `json:"mean"` StdDev float64 `json:"stdDev"` Histogram []HistoBin `json:"histogram"` Curve []CurvePoint `json:"curve"` }{ Mean: mean, StdDev: stddev, Histogram: histogram, Curve: convertToCurvePoints(x, y), }, } c.JSON(200, response) } // 直方图计算函数 func calculateHistogram(data []float64, bins int) []HistoBin { minV, maxV := floats.Min(data), floats.Max(data) binWidth := (maxV - minV) / float64(bins) counts := make([]int, bins) for _, v := range data { idx := int((v - minV) / binWidth) if idx == bins { // 处理最大值刚好等于maxV的情况 idx-- } counts[idx]++ } histogram := make([]HistoBin, bins) for i := 0; i < bins; i++ { start := minV + float64(i)*binWidth end := minV + float64(i+1)*binWidth histogram[i] = HistoBin{ BinStart: start, BinEnd: end, Count: counts[i], } } return histogram } // 转换曲线数据格式 func convertToCurvePoints(x, y []float64) []CurvePoint { points := make([]CurvePoint, len(x)) for i := range x { points[i] = CurvePoint{ X: x[i], Y: y[i], } } return points } func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error { var startTime int64 if record.TestTime > 0 { startTime = record.TestTime } else { startTime = record.StartTime } // 获取所有唯一的beltID var beltIDs []uint tx.Model(&models.HeartRate{}).Where("train_id = ?", record.TrainId). Select("DISTINCT belt_id").Pluck("belt_id", &beltIDs) // 对每个belt计算 for _, bid := range beltIDs { // 计算平均心率 ranges := getTimeRanges(startTime) averages, err := calculateAverages(tx, record.TrainId, bid, ranges) if err != nil { return err } // 曲线拟合 x := []float64{2, 4, 6} y := []float64{averages["2min"], averages["4min"], averages["6min"]} a, b, _ := quadraticFit(x, y) // 存储结果 analysis := models.BeltAnalysis{ TrainID: record.TrainId, RunType: record.RunType, BeltID: bid, Avg2min: averages["2min"], Avg4min: averages["4min"], Avg6min: averages["6min"], CurveParamA: a, CurveParamB: b, } if err := tx.Create(&analysis).Error; err != nil { return err } } return nil } func collectCurveParams(tx *gorm.DB) ([]float64, error) { var aValues []float64 // 查询所有记录的 CurveParamA 字段 err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &aValues).Error if err != nil { return nil, err } return aValues, nil } func calculateStats(data []float64) (mean, stddev float64) { mean = stat.Mean(data, nil) variance := stat.Variance(data, nil) stddev = math.Sqrt(variance) return } func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) { normal := distuv.Normal{ Mu: mean, Sigma: stddev, } minV := mean - 3*stddev // 从均值-3σ开始 maxV := mean + 3*stddev // 到均值+3σ结束 step := (maxV - minV) / float64(numPoints-1) for i := 0; i < numPoints; i++ { xi := minV + float64(i)*step yi := normal.Prob(xi) x = append(x, xi) y = append(y, yi) } return } func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) { averages := make(map[string]float64) for key, tr := range ranges { var avg float64 // 使用GORM Raw SQL提高效率[6,10](@ref) err := tx.Raw(` SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复 FROM heart_rates WHERE train_id = ? AND belt_id = ? AND time BETWEEN ? AND ?`, trainID, beltID, tr.Start, tr.End, ).Scan(&avg).Error if err != nil { return nil, err } averages[key] = avg } return averages, nil } //func quadraticFit(x []float64, y []float64) (float64, error) { // // 使用三点计算y=ax²+b的a值(x=[2,4,6]对应分钟) // if len(x) != 3 || len(y) != 3 { // return 0, errors.New("需要三个点") // } // // 构造方程组矩阵(简化计算) // a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0]) // return a, nil //} func quadraticFit(x []float64, y []float64) (float64, float64, error) { // 校验输入长度 if len(x) != 3 || len(y) != 3 { return 0, 0, errors.New("需要三个点") } // 计算各项累加值 var sumX4, sumX2, sumY, sumX2Y float64 for i := 0; i < 3; i++ { xi := x[i] xi2 := xi * xi sumX4 += xi2 * xi2 // x^4累加 sumX2 += xi2 // x^2累加 sumY += y[i] // y累加 sumX2Y += xi2 * y[i] // x²y累加 } // 计算行列式 determinant := sumX4*3 - sumX2*sumX2 if determinant == 0 { return 0, 0, errors.New("无解,行列式为零") } // 计算系数 a 和 b a := (sumX2Y*3 - sumY*sumX2) / determinant b := (sumX4*sumY - sumX2*sumX2Y) / determinant return a, b, nil } type TimeRange struct { Start int64 // 毫秒时间戳起点 End int64 // 毫秒时间戳终点 } func getTimeRanges(startTime int64) map[string]TimeRange { // 计算相对于训练开始时间的窗口 return map[string]TimeRange{ "2min": {Start: startTime + 120000, End: startTime + 240000}, // 第2分钟(120-240秒) "4min": {Start: startTime + 240000, End: startTime + 360000}, "6min": {Start: startTime + 360000, End: startTime + 480000}, } }