// Package controllers contains the Gin HTTP handlers that receive training
// data and persist it through GORM.
package controllers

import (
	"errors"
	"net/http"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"hr_receiver/config"
	"hr_receiver/models"
)

// TrainingController groups the training-record handlers around one database
// handle.
type TrainingController struct {
	DB *gorm.DB
}

// NewTrainingController returns a controller that uses the package-wide
// config.DB handle.
func NewTrainingController() *TrainingController {
	return &TrainingController{DB: config.DB}
}

// CreateTrainingRecord receives a training record as JSON, upserts it together
// with its heart-rate samples, and then runs the per-belt heart-rate analysis.
//
// All writes happen in a single transaction, so a failure in any step rolls
// back the whole upload. Responses:
//   - 400 when the request body cannot be bound to models.TrainRecord
//   - 500 when any database step or the analysis fails
//   - 201 with the record's TrainId on success
func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.TrainRecord

	// Bind and validate the JSON payload.
	if err := c.ShouldBindJSON(&record); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Persist everything inside one transaction.
	err := tc.DB.Transaction(func(tx *gorm.DB) error {
		// Upsert the main record: on a train_id conflict the listed columns
		// are overwritten with the incoming values. The associations are
		// omitted here because the heart rates are written explicitly below.
		upsertRecord := clause.OnConflict{
			Columns: []clause.Column{{Name: "train_id"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"max_heart_rate": record.MaxHeartRate,
				"start_time":     record.StartTime,
				"end_time":       record.EndTime,
				"duration":       record.Duration,
				"people_num":     record.PeopleNum,
				"name":           record.Name,
				"evaluation":     record.Evaluation,
			}),
		}
		if err := tx.Clauses(upsertRecord).
			Omit("HeartRates", "belts").
			Create(&record).Error; err != nil {
			return err
		}

		// Upsert each heart-rate sample: on an identifier conflict only the
		// value and time columns are refreshed.
		for i := range record.HeartRates {
			hr := &record.HeartRates[i]
			upsertSample := clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"value": hr.Value,
					"time":  hr.Time,
				}),
			}
			if err := tx.Clauses(upsertSample).Create(hr).Error; err != nil {
				return err
			}
		}

		return tc.heartRateAnalyze(tx, record)
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}

// heartRateAnalyze computes, for every belt that has heart-rate rows for the
// given training record, the average heart rate in three windows after the
// start time and the quadratic coefficient fitted through those averages, then
// stores one models.BeltAnalysis row per belt.
//
// It runs on the caller's transaction tx and returns the first query, fit, or
// insert error it encounters.
//
// NOTE(review): every call inserts new BeltAnalysis rows, while the caller
// upserts the record itself; re-uploading the same train_id presumably leaves
// duplicate analysis rows unless the table enforces uniqueness — confirm.
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
	// Collect the distinct belt IDs that reported samples for this record.
	var beltIDs []uint
	if err := tx.Model(&models.HeartRate{}).
		Where("train_id = ?", record.TrainId).
		Select("DISTINCT belt_id").
		Pluck("belt_id", &beltIDs).Error; err != nil {
		return err
	}

	// The windows and the x coordinates (minutes) are the same for every belt.
	ranges := getTimeRanges(record.StartTime)
	x := []float64{2, 4, 6}

	for _, bid := range beltIDs {
		// Average heart rate per window.
		averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
		if err != nil {
			return err
		}

		// Curve fit through the three window averages.
		y := []float64{averages["2min"], averages["4min"], averages["6min"]}
		a, err := quadraticFit(x, y)
		if err != nil {
			return err
		}

		// Store the result.
		analysis := models.BeltAnalysis{
			TrainID:     record.TrainId,
			RunType:     record.RunType,
			BeltID:      bid,
			Avg2min:     averages["2min"],
			Avg4min:     averages["4min"],
			Avg6min:     averages["6min"],
			CurveParamA: a,
		}
		if err := tx.Create(&analysis).Error; err != nil {
			return err
		}
	}
	return nil
}

// ReceiveTrainingData binds the JSON request body to models.TrainingData and
// inserts it through config.DB.
//
// It responds with 400 when binding fails, 500 when the insert fails, and 201
// with the new row's ID on success.
func ReceiveTrainingData(c *gin.Context) {
	var data models.TrainingData
	if err := c.ShouldBindJSON(&data); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": "Invalid request body: " + err.Error(),
		})
		return
	}

	if result := config.DB.Create(&data); result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "Failed to save data: " + result.Error.Error(),
		})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "Data saved successfully",
		"id":      data.ID,
	})
}

// calculateAverages returns, for each named window in ranges, the average of
// heart_rates.value for the given training record and belt whose time falls
// inside that window. The result is keyed by the same names as ranges.
//
// A window without any samples yields 0: AVG over zero rows is NULL, which
// cannot be scanned into a float64, so the query coalesces it to 0 instead of
// failing the surrounding transaction.
//
// NOTE(review): BETWEEN includes both bounds, so a sample lying exactly on a
// shared boundary of two adjacent windows is counted in both.
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
	averages := make(map[string]float64, len(ranges))
	for key, tr := range ranges {
		var avg float64
		// Raw SQL lets the database do the aggregation.
		err := tx.Raw(`
			SELECT COALESCE(AVG(value), 0)
			FROM heart_rates
			WHERE train_id = ? AND belt_id = ?
			AND time BETWEEN ? AND ?`,
			trainID, beltID, tr.Start, tr.End,
		).Scan(&avg).Error
		if err != nil {
			return nil, err
		}
		averages[key] = avg
	}
	return averages, nil
}

// quadraticFit returns the coefficient a of y = a*x^2 + b computed from three
// points, using the second difference of y divided by the second difference of
// x^2.
//
// It returns an error when x or y does not hold exactly three values, or when
// the x values make the denominator zero (the division would otherwise yield
// Inf or NaN).
func quadraticFit(x []float64, y []float64) (float64, error) {
	if len(x) != 3 || len(y) != 3 {
		return 0, errors.New("需要三个点")
	}

	denom := x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0]
	if denom == 0 {
		return 0, errors.New("quadratic fit: degenerate x values")
	}
	return (y[2] - 2*y[1] + y[0]) / denom, nil
}

// TimeRange is a time window expressed as millisecond timestamps.
type TimeRange struct {
	Start int64 // start of the window, millisecond timestamp
	End   int64 // end of the window, millisecond timestamp
}

// getTimeRanges returns the three consecutive two-minute windows, relative to
// startTime (a millisecond timestamp), over which heart rates are averaged:
// "2min" covers 120-240 s, "4min" covers 240-360 s and "6min" covers 360-480 s
// after the start.
func getTimeRanges(startTime int64) map[string]TimeRange {
	const window int64 = 120000 // two minutes in milliseconds
	return map[string]TimeRange{
		"2min": {Start: startTime + window, End: startTime + 2*window},
		"4min": {Start: startTime + 2*window, End: startTime + 3*window},
		"6min": {Start: startTime + 3*window, End: startTime + 4*window},
	}
}