Files
hr_data_analyzer/controllers/train.go
2025-04-01 10:13:38 +08:00

186 lines
4.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controllers
import (
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"hr_receiver/config"
"hr_receiver/models"
"net/http"
)
// TrainingController groups the HTTP handlers that persist and analyse
// training records.
type TrainingController struct {
	// DB is the gorm handle used by the controller's handlers.
	DB *gorm.DB
}
// NewTrainingController returns a TrainingController whose DB field is set to
// the value config.DB holds at the time of the call.
func NewTrainingController() *TrainingController {
	controller := new(TrainingController)
	controller.DB = config.DB
	return controller
}
// CreateTrainingRecord receives a training record as JSON, stores it together
// with its heart-rate samples, and runs the per-belt analysis, all inside one
// database transaction.
//
// Responses:
//   - 400 with the binding error when the request body cannot be bound.
//   - 500 with the error text when any step inside the transaction fails.
//   - 201 with a confirmation message and the record's TrainId on success.
func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.TrainRecord

	// Bind and validate the JSON payload.
	if bindErr := c.ShouldBindJSON(&record); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	// persist is the transaction body; every step returns its error so the
	// enclosing Transaction call receives it.
	persist := func(tx *gorm.DB) error {
		// Upsert the main record: on a train_id conflict, overwrite the
		// listed columns with the values from the incoming record.
		recordUpsert := clause.OnConflict{
			Columns: []clause.Column{{Name: "train_id"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"max_heart_rate": record.MaxHeartRate,
				"start_time":     record.StartTime,
				"end_time":       record.EndTime,
				"duration":       record.Duration,
				"people_num":     record.PeopleNum,
				"name":           record.Name,
				"evaluation":     record.Evaluation,
			}),
		}
		// "HeartRates" and "belts" are omitted here; the heart-rate samples
		// are written one by one below.
		saveRecord := tx.Clauses(recordUpsert).Omit("HeartRates", "belts").Create(&record)
		if saveRecord.Error != nil {
			return saveRecord.Error
		}

		// Upsert each heart-rate sample: on an identifier conflict, refresh
		// its value and time.
		for i := range record.HeartRates {
			sampleUpsert := clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"value": record.HeartRates[i].Value,
					"time":  record.HeartRates[i].Time,
				}),
			}
			saveSample := tx.Clauses(sampleUpsert).Create(&record.HeartRates[i])
			if saveSample.Error != nil {
				return saveSample.Error
			}
		}

		// Derive and store the per-belt analysis within the same transaction.
		return tc.heartRateAnalyze(tx, record)
	}

	if txErr := tc.DB.Transaction(persist); txErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": txErr.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}
// heartRateAnalyze computes heart-rate statistics for every belt that has
// heart-rate rows for the given training record and stores one
// models.BeltAnalysis row per belt.
//
// For each belt it averages the heart rate over the windows produced by
// getTimeRanges (relative to record.StartTime), fits the curve parameter via
// quadraticFit over the points (2, 4, 6), and inserts the result through tx.
//
// It returns the first error encountered: from the belt lookup, the average
// queries, the curve fit, or the insert. Nothing is written for belts after
// the failing one.
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
	// Collect the distinct belt IDs recorded for this training session. The
	// query error is checked so a failed lookup is not mistaken for
	// "no belts".
	var beltIDs []uint
	if err := tx.Model(&models.HeartRate{}).
		Where("train_id = ?", record.TrainId).
		Select("DISTINCT belt_id").
		Pluck("belt_id", &beltIDs).Error; err != nil {
		return err
	}

	// The windows and the x coordinates are the same for every belt, so they
	// are built once outside the loop.
	ranges := getTimeRanges(record.StartTime)
	x := []float64{2, 4, 6}

	for _, bid := range beltIDs {
		// Average heart rate per window for this belt.
		averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
		if err != nil {
			return err
		}

		// Curve fit over the three window averages.
		y := []float64{averages["2min"], averages["4min"], averages["6min"]}
		a, err := quadraticFit(x, y)
		if err != nil {
			return err
		}

		// Persist the result.
		analysis := models.BeltAnalysis{
			TrainID:     record.TrainId,
			RunType:     record.RunType,
			BeltID:      bid,
			Avg2min:     averages["2min"],
			Avg4min:     averages["4min"],
			Avg6min:     averages["6min"],
			CurveParamA: a,
		}
		if err := tx.Create(&analysis).Error; err != nil {
			return err
		}
	}
	return nil
}
// ReceiveTrainingData binds the JSON request body to a models.TrainingData
// value and inserts it through config.DB.
//
// Responses:
//   - 400 when the body cannot be bound.
//   - 500 when the insert fails.
//   - 201 with a confirmation message and the new row's ID on success.
func ReceiveTrainingData(c *gin.Context) {
	var data models.TrainingData

	bindErr := c.ShouldBindJSON(&data)
	if bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: " + bindErr.Error()})
		return
	}

	if saveErr := config.DB.Create(&data).Error; saveErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save data: " + saveErr.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "Data saved successfully",
		"id":      data.ID,
	})
}
// calculateAverages returns, for each entry in ranges, the average of the
// heart_rates.value column for the given training session and belt whose time
// lies between the range's Start and End (BETWEEN, so both ends are
// inclusive).
//
// The result map uses the same keys as ranges. A window that contains no rows
// yields 0. The first query error aborts the computation and is returned with
// a nil map.
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
	averages := make(map[string]float64, len(ranges))
	for key, tr := range ranges {
		var avg float64
		// AVG over zero rows is NULL, and NULL cannot be scanned into a plain
		// float64; COALESCE turns the empty-window case into 0 so a session
		// that is missing samples for a window does not fail the whole
		// request.
		err := tx.Raw(`
			SELECT COALESCE(AVG(value), 0)
			FROM heart_rates
			WHERE train_id = ?
			AND belt_id = ?
			AND time BETWEEN ? AND ?`,
			trainID, beltID, tr.Start, tr.End,
		).Scan(&avg).Error
		if err != nil {
			return nil, err
		}
		averages[key] = avg
	}
	return averages, nil
}
// quadraticFit returns the coefficient a of the model y = a*x^2 + b computed
// from exactly three points (x[i], y[i]).
//
// For that model the constant b cancels in the combination
// y[2] - 2*y[1] + y[0] = a * (x[2]^2 - 2*x[1]^2 + x[0]^2), so a is the ratio
// of the two sides.
//
// It returns an error when x or y does not hold exactly three values, or when
// the x values make the denominator zero (a would otherwise be NaN or Inf).
func quadraticFit(x []float64, y []float64) (float64, error) {
	if len(x) != 3 || len(y) != 3 {
		return 0, errors.New("需要三个点")
	}

	denom := x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0]
	if denom == 0 {
		// e.g. x = [1, 5, 7]: 49 - 50 + 1 == 0, so a is undetermined.
		return 0, errors.New("degenerate x values: cannot solve for a")
	}

	return (y[2] - 2*y[1] + y[0]) / denom, nil
}
// TimeRange is a time window described by two millisecond timestamps.
type TimeRange struct {
	// Start is the beginning of the window and End is its end, both as
	// millisecond timestamps.
	Start, End int64
}
// getTimeRanges returns the three analysis windows, keyed "2min", "4min" and
// "6min", expressed relative to startTime (a millisecond timestamp).
//
// Each window is 120000 ms wide. "2min" starts 120000 ms after startTime, and
// each following window starts where the previous one ends.
func getTimeRanges(startTime int64) map[string]TimeRange {
	const windowMillis int64 = 120000

	labels := [...]string{"2min", "4min", "6min"}
	windows := make(map[string]TimeRange, len(labels))
	for i, label := range labels {
		begin := startTime + int64(i+1)*windowMillis
		windows[label] = TimeRange{Start: begin, End: begin + windowMillis}
	}
	return windows
}