From 7a2a44e327237a97278ebadf5c9f141121f11ff3 Mon Sep 17 00:00:00 2001 From: laoboli <1293528695@qq.com> Date: Mon, 4 Aug 2025 11:18:07 +0800 Subject: [PATCH] feat: rank query --- controllers/step_train.go | 125 +++++++++++++++++++++----------------- 1 file changed, 69 insertions(+), 56 deletions(-) diff --git a/controllers/step_train.go b/controllers/step_train.go index 0a44421..85ef9ad 100644 --- a/controllers/step_train.go +++ b/controllers/step_train.go @@ -585,41 +585,47 @@ func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值 } -// 存储回归结果到数据库 -func (tc *StepTrainingController) SaveRegressionResult(trainId uint, result models.RegressionResult) error { - result.TrainId = trainId +// 存储回归结果到数据库(支持多种回归类型) +func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error { + return tc.DB.Transaction(func(tx *gorm.DB) error { + for i := range results { + results[i].TrainId = trainId + // 使用复合唯一约束确保每种回归类型只存储一条记录 + err := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "id"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "equation": results[i].Equation, + "slope": results[i].Slope, + "intercept": results[i].Intercept, + "log_a": results[i].LogA, + "log_b": results[i].LogB, + "quadratic_a": results[i].QuadraticA, + "quadratic_b": results[i].QuadraticB, + "quadratic_c": results[i].QuadraticC, + "r_squared": results[i].RSquared, + "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), + }), + }).Create(&results[i]).Error - return tc.DB.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "id"}}, - DoUpdates: clause.Assignments(map[string]interface{}{ - "equation": result.Equation, - "slope": result.Slope, - "intercept": result.Intercept, - "log_a": result.LogA, - "log_b": result.LogB, - "quadratic_a": result.QuadraticA, - "quadratic_b": result.QuadraticB, - "quadratic_c": result.QuadraticC, - "r_squared": result.RSquared, - "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), - }), - }).Create(&result).Error + if err != nil { + return err + } + } + return nil + }) } -// 获取或计算回归结果 -func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models.RegressionResult, error) { - // 首先尝试从数据库获取 - var result models.RegressionResult - err := tc.DB.Where("train_id = ?", trainId).First(&result).Error +// 获取或计算回归结果(返回多种回归类型列表) +func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) { + // 尝试从数据库获取所有类型的回归结果 + var results []models.RegressionResult + err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error - // 如果找到记录,直接返回 - if err == nil { - return result, nil - } - - // 如果错误不是记录不存在,返回错误 - if !errors.Is(err, gorm.ErrRecordNotFound) { - return models.RegressionResult{}, err + // 如果已存在三种类型的结果,直接返回 + if err == nil && len(results) >= 3 { + return results, nil } // 查询训练记录及相关数据 @@ -629,52 +635,59 @@ func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models Preload("HeartRates", "heart_rate_type = ?", 1). Preload("StrideFreqs", "predict_value = ?", 1). First(&record).Error; err != nil { - return models.RegressionResult{}, err + return nil, err } - // 计算心率平均值(模仿Flutter的calculateSegmentAveragesByRealStep) + // 计算心率平均值 averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs) if len(averages) == 0 { - return models.RegressionResult{}, errors.New("无足够数据进行回归计算") + return nil, errors.New("无足够数据进行回归计算") } - // 计算三种回归 - result = models.RegressionResult{ - TrainId: trainId, - } + // 创建三种回归类型的结果 + results = make([]models.RegressionResult, 3) // 线性回归 linearRes := performLinearRegression(averages) - result.Slope = linearRes.Slope - result.Intercept = linearRes.Intercept - result.RSquared = linearRes.RSquared - result.Equation = "线性回归: " + linearRes.Equation + results[0] = models.RegressionResult{ + RegressionType: models.LinearRegression, + TrainId: trainId, + Equation: linearRes.Equation, + Slope: linearRes.Slope, + Intercept: linearRes.Intercept, + RSquared: linearRes.RSquared, + } // 对数回归 logRes := performLogarithmicRegression(averages) - result.LogA = logRes.LogA - result.LogB = logRes.LogB - if result.Equation != "" { - result.Equation += "\n" + results[1] = models.RegressionResult{ + RegressionType: models.LogarithmicRegression, + TrainId: trainId, + Equation: logRes.Equation, + LogA: logRes.LogA, + LogB: logRes.LogB, + RSquared: logRes.RSquared, } - result.Equation += "对数回归: " + logRes.Equation // 二次回归 quadRes := performQuadraticRegression(averages) - result.QuadraticA = quadRes.QuadraticA - result.QuadraticB = quadRes.QuadraticB - result.QuadraticC = quadRes.QuadraticC - if result.Equation != "" { - result.Equation += "\n" + results[2] = models.RegressionResult{ + RegressionType: models.QuadraticRegression, + TrainId: trainId, + Equation: quadRes.Equation, + QuadraticA: quadRes.QuadraticA, + QuadraticB: quadRes.QuadraticB, + QuadraticC: quadRes.QuadraticC, + RSquared: quadRes.RSquared, } - result.Equation += "二次回归: " + quadRes.Equation - // 保存计算结果到数据库 - if err := tc.SaveRegressionResult(trainId, result); err != nil { + // 批量保存结果到数据库 + if err := tc.SaveRegressionResults(trainId, results); err != nil { log.Printf("保存回归结果失败: %v", err) + return nil, err } - return result, nil + return results, nil } // 新增接口:获取回归结果