diff --git a/controllers/train.go b/controllers/train.go index 19df33b..aabb410 100644 --- a/controllers/train.go +++ b/controllers/train.go @@ -212,7 +212,12 @@ func convertToCurvePoints(x, y []float64) []CurvePoint { } func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error { - startTime := record.StartTime + var startTime int64 + if record.TestTime > 0 { + startTime = record.TestTime + } else { + startTime = record.StartTime + } // 获取所有唯一的beltID var beltIDs []uint @@ -231,7 +236,7 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR // 曲线拟合 x := []float64{2, 4, 6} y := []float64{averages["2min"], averages["4min"], averages["6min"]} - a, _ := quadraticFit(x, y) + a, b, _ := quadraticFit(x, y) // 存储结果 analysis := models.BeltAnalysis{ @@ -242,6 +247,7 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR Avg4min: averages["4min"], Avg6min: averages["6min"], CurveParamA: a, + CurveParamB: b, } if err := tx.Create(&analysis).Error; err != nil { return err @@ -305,14 +311,44 @@ func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string return averages, nil } -func quadraticFit(x []float64, y []float64) (float64, error) { - // 使用三点计算y=ax²+b的a值(x=[2,4,6]对应分钟) +//func quadraticFit(x []float64, y []float64) (float64, error) { +// // 使用三点计算y=ax²+b的a值(x=[2,4,6]对应分钟) +// if len(x) != 3 || len(y) != 3 { +// return 0, errors.New("需要三个点") +// } +// // 构造方程组矩阵(简化计算) +// a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0]) +// return a, nil +//} + +func quadraticFit(x []float64, y []float64) (float64, float64, error) { + // 校验输入长度 if len(x) != 3 || len(y) != 3 { - return 0, errors.New("需要三个点") + return 0, 0, errors.New("需要三个点") } - // 构造方程组矩阵(简化计算) - a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0]) - return a, nil + + // 计算各项累加值 + var sumX4, sumX2, sumY, sumX2Y float64 + for i := 0; i < 3; i++ { + xi := x[i] + xi2 := xi * xi + sumX4 += xi2 * xi2 // x^4累加 + sumX2 += xi2 // x^2累加 + sumY += y[i] // y累加 + sumX2Y += xi2 * y[i] // x²y累加 + } + + // 计算行列式 + determinant := sumX4*3 - sumX2*sumX2 + if determinant == 0 { + return 0, 0, errors.New("无解,行列式为零") + } + + // 计算系数 a 和 b + a := (sumX2Y*3 - sumY*sumX2) / determinant + b := (sumX4*sumY - sumX2*sumX2Y) / determinant + + return a, b, nil } type TimeRange struct { diff --git a/models/training.go b/models/training.go index 810b655..ed44a6b 100644 --- a/models/training.go +++ b/models/training.go @@ -28,6 +28,7 @@ type BeltAnalysis struct { Avg4min float64 `gorm:"type:double precision"` // 第4分钟平均心率 Avg6min float64 `gorm:"type:double precision"` // 第6分钟平均心率 CurveParamA float64 `gorm:"type:double precision"` // 拟合参数a值 + CurveParamB float64 `gorm:"type:double precision"` // 拟合参数a值 } // 中间计算结构(无需持久化) @@ -47,9 +48,10 @@ type HeartRate struct { // 对应Flutter的TrainRecord结构 type TrainRecord struct { gorm.Model - TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段 - StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳 - EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref) + TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段 + StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳 + TestTime int64 `gorm:"type:bigint" json:"testTime"` // 开始时间戳 + EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref) Name string `gorm:"size:100" json:"name"` RunType string `gorm:"size:100" json:"RunType"` MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`