refactor: train.

This commit is contained in:
2025-06-25 08:53:08 +08:00
parent f4df1724b2
commit 8510d19baa
2 changed files with 49 additions and 11 deletions

View File

@ -212,7 +212,12 @@ func convertToCurvePoints(x, y []float64) []CurvePoint {
}
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
startTime := record.StartTime
var startTime int64
if record.TestTime > 0 {
startTime = record.TestTime
} else {
startTime = record.StartTime
}
// 获取所有唯一的beltID
var beltIDs []uint
@ -231,7 +236,7 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR
// 曲线拟合
x := []float64{2, 4, 6}
y := []float64{averages["2min"], averages["4min"], averages["6min"]}
a, _ := quadraticFit(x, y)
a, b, _ := quadraticFit(x, y)
// 存储结果
analysis := models.BeltAnalysis{
@ -242,6 +247,7 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR
Avg4min: averages["4min"],
Avg6min: averages["6min"],
CurveParamA: a,
CurveParamB: b,
}
if err := tx.Create(&analysis).Error; err != nil {
return err
@ -305,14 +311,44 @@ func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string
return averages, nil
}
func quadraticFit(x []float64, y []float64) (float64, error) {
// 使用三点计算y=ax²+b的a值x=[2,4,6]对应分钟)
//func quadraticFit(x []float64, y []float64) (float64, error) {
// // 使用三点计算y=ax²+b的a值x=[2,4,6]对应分钟)
// if len(x) != 3 || len(y) != 3 {
// return 0, errors.New("需要三个点")
// }
// // 构造方程组矩阵(简化计算)
// a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
// return a, nil
//}
func quadraticFit(x []float64, y []float64) (float64, float64, error) {
// 校验输入长度
if len(x) != 3 || len(y) != 3 {
return 0, errors.New("需要三个点")
return 0, 0, errors.New("需要三个点")
}
// 构造方程组矩阵(简化计算)
a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
return a, nil
// 计算各项累加值
var sumX4, sumX2, sumY, sumX2Y float64
for i := 0; i < 3; i++ {
xi := x[i]
xi2 := xi * xi
sumX4 += xi2 * xi2 // x^4累加
sumX2 += xi2 // x^2累加
sumY += y[i] // y累加
sumX2Y += xi2 * y[i] // x²y累加
}
// 计算行列式
determinant := sumX4*3 - sumX2*sumX2
if determinant == 0 {
return 0, 0, errors.New("无解,行列式为零")
}
// 计算系数 a 和 b
a := (sumX2Y*3 - sumY*sumX2) / determinant
b := (sumX4*sumY - sumX2*sumX2Y) / determinant
return a, b, nil
}
type TimeRange struct {

View File

@ -28,6 +28,7 @@ type BeltAnalysis struct {
Avg4min float64 `gorm:"type:double precision"` // 第4分钟平均心率
Avg6min float64 `gorm:"type:double precision"` // 第6分钟平均心率
CurveParamA float64 `gorm:"type:double precision"` // 拟合参数a值
CurveParamB float64 `gorm:"type:double precision"` // 拟合参数a值
}
// 中间计算结构(无需持久化)
@ -47,9 +48,10 @@ type HeartRate struct {
// 对应Flutter的TrainRecord结构
type TrainRecord struct {
gorm.Model
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
TestTime int64 `gorm:"type:bigint" json:"testTime"` // 开始时间戳
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
Name string `gorm:"size:100" json:"name"`
RunType string `gorm:"size:100" json:"RunType"`
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`