feat: rank query
This commit is contained in:
@ -585,41 +585,47 @@ func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps
|
||||
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
|
||||
}
|
||||
|
||||
// 存储回归结果到数据库
|
||||
func (tc *StepTrainingController) SaveRegressionResult(trainId uint, result models.RegressionResult) error {
|
||||
result.TrainId = trainId
|
||||
|
||||
return tc.DB.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
// 存储回归结果到数据库(支持多种回归类型)
|
||||
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
|
||||
return tc.DB.Transaction(func(tx *gorm.DB) error {
|
||||
for i := range results {
|
||||
results[i].TrainId = trainId
|
||||
// 使用复合唯一约束确保每种回归类型只存储一条记录
|
||||
err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "id"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"equation": result.Equation,
|
||||
"slope": result.Slope,
|
||||
"intercept": result.Intercept,
|
||||
"log_a": result.LogA,
|
||||
"log_b": result.LogB,
|
||||
"quadratic_a": result.QuadraticA,
|
||||
"quadratic_b": result.QuadraticB,
|
||||
"quadratic_c": result.QuadraticC,
|
||||
"r_squared": result.RSquared,
|
||||
"equation": results[i].Equation,
|
||||
"slope": results[i].Slope,
|
||||
"intercept": results[i].Intercept,
|
||||
"log_a": results[i].LogA,
|
||||
"log_b": results[i].LogB,
|
||||
"quadratic_a": results[i].QuadraticA,
|
||||
"quadratic_b": results[i].QuadraticB,
|
||||
"quadratic_c": results[i].QuadraticC,
|
||||
"r_squared": results[i].RSquared,
|
||||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}),
|
||||
}).Create(&result).Error
|
||||
}).Create(&results[i]).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// 获取或计算回归结果
|
||||
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models.RegressionResult, error) {
|
||||
// 首先尝试从数据库获取
|
||||
var result models.RegressionResult
|
||||
err := tc.DB.Where("train_id = ?", trainId).First(&result).Error
|
||||
// 获取或计算回归结果(返回多种回归类型列表)
|
||||
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
|
||||
// 尝试从数据库获取所有类型的回归结果
|
||||
var results []models.RegressionResult
|
||||
err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
|
||||
|
||||
// 如果找到记录,直接返回
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 如果错误不是记录不存在,返回错误
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return models.RegressionResult{}, err
|
||||
// 如果已存在三种类型的结果,直接返回
|
||||
if err == nil && len(results) >= 3 {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 查询训练记录及相关数据
|
||||
@ -629,52 +635,59 @@ func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models
|
||||
Preload("HeartRates", "heart_rate_type = ?", 1).
|
||||
Preload("StrideFreqs", "predict_value = ?", 1).
|
||||
First(&record).Error; err != nil {
|
||||
return models.RegressionResult{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算心率平均值(模仿Flutter的calculateSegmentAveragesByRealStep)
|
||||
// 计算心率平均值
|
||||
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
|
||||
if len(averages) == 0 {
|
||||
return models.RegressionResult{}, errors.New("无足够数据进行回归计算")
|
||||
return nil, errors.New("无足够数据进行回归计算")
|
||||
}
|
||||
|
||||
// 计算三种回归
|
||||
result = models.RegressionResult{
|
||||
TrainId: trainId,
|
||||
}
|
||||
// 创建三种回归类型的结果
|
||||
results = make([]models.RegressionResult, 3)
|
||||
|
||||
// 线性回归
|
||||
linearRes := performLinearRegression(averages)
|
||||
result.Slope = linearRes.Slope
|
||||
result.Intercept = linearRes.Intercept
|
||||
result.RSquared = linearRes.RSquared
|
||||
result.Equation = "线性回归: " + linearRes.Equation
|
||||
results[0] = models.RegressionResult{
|
||||
RegressionType: models.LinearRegression,
|
||||
TrainId: trainId,
|
||||
Equation: linearRes.Equation,
|
||||
Slope: linearRes.Slope,
|
||||
Intercept: linearRes.Intercept,
|
||||
RSquared: linearRes.RSquared,
|
||||
}
|
||||
|
||||
// 对数回归
|
||||
logRes := performLogarithmicRegression(averages)
|
||||
result.LogA = logRes.LogA
|
||||
result.LogB = logRes.LogB
|
||||
if result.Equation != "" {
|
||||
result.Equation += "\n"
|
||||
results[1] = models.RegressionResult{
|
||||
RegressionType: models.LogarithmicRegression,
|
||||
TrainId: trainId,
|
||||
Equation: logRes.Equation,
|
||||
LogA: logRes.LogA,
|
||||
LogB: logRes.LogB,
|
||||
RSquared: logRes.RSquared,
|
||||
}
|
||||
result.Equation += "对数回归: " + logRes.Equation
|
||||
|
||||
// 二次回归
|
||||
quadRes := performQuadraticRegression(averages)
|
||||
result.QuadraticA = quadRes.QuadraticA
|
||||
result.QuadraticB = quadRes.QuadraticB
|
||||
result.QuadraticC = quadRes.QuadraticC
|
||||
if result.Equation != "" {
|
||||
result.Equation += "\n"
|
||||
results[2] = models.RegressionResult{
|
||||
RegressionType: models.QuadraticRegression,
|
||||
TrainId: trainId,
|
||||
Equation: quadRes.Equation,
|
||||
QuadraticA: quadRes.QuadraticA,
|
||||
QuadraticB: quadRes.QuadraticB,
|
||||
QuadraticC: quadRes.QuadraticC,
|
||||
RSquared: quadRes.RSquared,
|
||||
}
|
||||
result.Equation += "二次回归: " + quadRes.Equation
|
||||
|
||||
// 保存计算结果到数据库
|
||||
if err := tc.SaveRegressionResult(trainId, result); err != nil {
|
||||
// 批量保存结果到数据库
|
||||
if err := tc.SaveRegressionResults(trainId, results); err != nil {
|
||||
log.Printf("保存回归结果失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 新增接口:获取回归结果
|
||||
|
||||
Reference in New Issue
Block a user