feat: rank query

This commit is contained in:
2025-08-04 11:18:07 +08:00
parent b00767dfcd
commit 7a2a44e327

View File

@ -585,41 +585,47 @@ func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值 return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
} }
// 存储回归结果到数据库 // 存储回归结果到数据库(支持多种回归类型)
func (tc *StepTrainingController) SaveRegressionResult(trainId uint, result models.RegressionResult) error { func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
result.TrainId = trainId return tc.DB.Transaction(func(tx *gorm.DB) error {
for i := range results {
results[i].TrainId = trainId
// 使用复合唯一约束确保每种回归类型只存储一条记录
err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "id"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"equation": results[i].Equation,
"slope": results[i].Slope,
"intercept": results[i].Intercept,
"log_a": results[i].LogA,
"log_b": results[i].LogB,
"quadratic_a": results[i].QuadraticA,
"quadratic_b": results[i].QuadraticB,
"quadratic_c": results[i].QuadraticC,
"r_squared": results[i].RSquared,
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
}),
}).Create(&results[i]).Error
return tc.DB.Clauses(clause.OnConflict{ if err != nil {
Columns: []clause.Column{{Name: "id"}}, return err
DoUpdates: clause.Assignments(map[string]interface{}{ }
"equation": result.Equation, }
"slope": result.Slope, return nil
"intercept": result.Intercept, })
"log_a": result.LogA,
"log_b": result.LogB,
"quadratic_a": result.QuadraticA,
"quadratic_b": result.QuadraticB,
"quadratic_c": result.QuadraticC,
"r_squared": result.RSquared,
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
}),
}).Create(&result).Error
} }
// 获取或计算回归结果 // 获取或计算回归结果(返回多种回归类型列表)
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models.RegressionResult, error) { func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
// 首先尝试从数据库获取 // 尝试从数据库获取所有类型的回归结果
var result models.RegressionResult var results []models.RegressionResult
err := tc.DB.Where("train_id = ?", trainId).First(&result).Error err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
// 如果找到记录,直接返回 // 如果已存在三种类型的结果,直接返回
if err == nil { if err == nil && len(results) >= 3 {
return result, nil return results, nil
}
// 如果错误不是记录不存在,返回错误
if !errors.Is(err, gorm.ErrRecordNotFound) {
return models.RegressionResult{}, err
} }
// 查询训练记录及相关数据 // 查询训练记录及相关数据
@ -629,52 +635,59 @@ func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models
Preload("HeartRates", "heart_rate_type = ?", 1). Preload("HeartRates", "heart_rate_type = ?", 1).
Preload("StrideFreqs", "predict_value = ?", 1). Preload("StrideFreqs", "predict_value = ?", 1).
First(&record).Error; err != nil { First(&record).Error; err != nil {
return models.RegressionResult{}, err return nil, err
} }
// 计算心率平均值模仿Flutter的calculateSegmentAveragesByRealStep // 计算心率平均值
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs) averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
if len(averages) == 0 { if len(averages) == 0 {
return models.RegressionResult{}, errors.New("无足够数据进行回归计算") return nil, errors.New("无足够数据进行回归计算")
} }
// 计算三种回归 // 创建三种回归类型的结果
result = models.RegressionResult{ results = make([]models.RegressionResult, 3)
TrainId: trainId,
}
// 线性回归 // 线性回归
linearRes := performLinearRegression(averages) linearRes := performLinearRegression(averages)
result.Slope = linearRes.Slope results[0] = models.RegressionResult{
result.Intercept = linearRes.Intercept RegressionType: models.LinearRegression,
result.RSquared = linearRes.RSquared TrainId: trainId,
result.Equation = "线性回归: " + linearRes.Equation Equation: linearRes.Equation,
Slope: linearRes.Slope,
Intercept: linearRes.Intercept,
RSquared: linearRes.RSquared,
}
// 对数回归 // 对数回归
logRes := performLogarithmicRegression(averages) logRes := performLogarithmicRegression(averages)
result.LogA = logRes.LogA results[1] = models.RegressionResult{
result.LogB = logRes.LogB RegressionType: models.LogarithmicRegression,
if result.Equation != "" { TrainId: trainId,
result.Equation += "\n" Equation: logRes.Equation,
LogA: logRes.LogA,
LogB: logRes.LogB,
RSquared: logRes.RSquared,
} }
result.Equation += "对数回归: " + logRes.Equation
// 二次回归 // 二次回归
quadRes := performQuadraticRegression(averages) quadRes := performQuadraticRegression(averages)
result.QuadraticA = quadRes.QuadraticA results[2] = models.RegressionResult{
result.QuadraticB = quadRes.QuadraticB RegressionType: models.QuadraticRegression,
result.QuadraticC = quadRes.QuadraticC TrainId: trainId,
if result.Equation != "" { Equation: quadRes.Equation,
result.Equation += "\n" QuadraticA: quadRes.QuadraticA,
QuadraticB: quadRes.QuadraticB,
QuadraticC: quadRes.QuadraticC,
RSquared: quadRes.RSquared,
} }
result.Equation += "二次回归: " + quadRes.Equation
// 保存计算结果到数据库 // 批量保存结果到数据库
if err := tc.SaveRegressionResult(trainId, result); err != nil { if err := tc.SaveRegressionResults(trainId, results); err != nil {
log.Printf("保存回归结果失败: %v", err) log.Printf("保存回归结果失败: %v", err)
return nil, err
} }
return result, nil return results, nil
} }
// 新增接口:获取回归结果 // 新增接口:获取回归结果