package controllers

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"math"
	"net/http"
	"sort"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/sajari/regression"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"hr_receiver/config"
	"hr_receiver/models"
)

// StepTrainingController groups the HTTP handlers and regression helpers for
// step-training records. All database access goes through DB.
type StepTrainingController struct {
	DB *gorm.DB
}

// NewStepTrainingController returns a controller backed by the shared
// config.DB handle.
func NewStepTrainingController() *StepTrainingController {
	return &StepTrainingController{DB: config.DB}
}

// CreateTrainingRecord receives one training record as JSON and upserts it
// together with its heart-rate and stride-frequency samples inside a single
// transaction.
//
// Responses: 400 when the body cannot be bound, 500 when the transaction
// fails, 201 with the record's TrainId on success.
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.StepTrainRecord

	// Bind and validate the JSON payload.
	if err := c.ShouldBindJSON(&record); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	err := tc.DB.Transaction(func(tx *gorm.DB) error {
		// Upsert the main record; a conflict on train_id updates the listed
		// columns. The associations are omitted here and written explicitly
		// below so that each child row gets its own conflict handling.
		if err := tx.Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "train_id"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"max_heart_rate": record.MaxHeartRate,
				"start_time":     record.StartTime,
				"end_time":       record.EndTime,
				"duration":       record.Duration,
				"dead_zone":      record.DeadZone,
				"name":           record.Name,
				"evaluation":     record.Evaluation,
			}),
		}).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
			return err
		}

		// Upsert the heart-rate samples one by one, keyed on identifier.
		for i := range record.HeartRates {
			hr := &record.HeartRates[i]
			if err := tx.Clauses(clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"heart_rate_type": hr.HeartRateType,
					"value":           hr.Value,
					"time":            hr.Time,
				}),
			}).Create(hr).Error; err != nil {
				return err
			}
		}

		// Upsert the stride-frequency samples one by one, keyed on identifier.
		for i := range record.StrideFreqs {
			sf := &record.StrideFreqs[i]
			if err := tx.Clauses(clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"value": sf.Value,
					"time":  sf.Time,
				}),
			}).Create(sf).Error; err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}

// GetTrainingRecords returns one page of training records ordered by
// start_time descending, plus pagination metadata.
//
// Query parameters: pageNum (default 1) and pageSize (default 10). A pageNum
// below 1 is reset to 1; a pageSize outside 1..100 is reset to 10.
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
	// PaginationParams is the query-string binding target.
	type PaginationParams struct {
		// PageNum is the 1-based page number.
		PageNum int `form:"pageNum,default=1"`
		// PageSize is the number of records per page.
		PageSize int `form:"pageSize,default=10"`
	}

	var params PaginationParams
	if err := c.ShouldBindQuery(&params); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Normalise out-of-range pagination values instead of rejecting them.
	if params.PageNum < 1 {
		params.PageNum = 1
	}
	if params.PageSize < 1 || params.PageSize > 100 {
		params.PageSize = 10
	}
	offset := (params.PageNum - 1) * params.PageSize

	var (
		records   []models.StepTrainRecord
		totalRows int64
	)

	// Total number of records, used for the page count.
	if err := tc.DB.Model(&models.StepTrainRecord{}).Count(&totalRows).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
		return
	}

	// Fetch the requested page, newest start_time first.
	result := tc.DB.
		Order("start_time DESC").
		Offset(offset).
		Limit(params.PageSize).
		Find(&records)
	if result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
		return
	}

	totalPages := int(math.Ceil(float64(totalRows) / float64(params.PageSize)))

	c.JSON(http.StatusOK, gin.H{
		"message": "查询成功",
		"data": gin.H{
			"list": records,
			"pagination": gin.H{
				"currentPage": params.PageNum,
				"pageSize":    params.PageSize,
				"totalPage":   totalPages,
				"totalList":   totalRows,
			},
		},
	})
}

// GetTrainingRecordByTrainId returns the training record identified by the
// trainId path parameter, with its HeartRates and StrideFreqs preloaded.
//
// Responses: 400 for an empty or non-numeric (including negative) ID, 404 when
// no record matches, 500 on other database errors.
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
	trainId := c.Param("trainId")
	if trainId == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
		return
	}

	// Parse as unsigned so that a negative ID is rejected here rather than
	// wrapping around to a huge value in the uint conversion below.
	tid, err := strconv.ParseUint(trainId, 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
		return
	}

	var record models.StepTrainRecord
	result := tc.DB.Where("train_id = ?", uint(tid)).
		Preload("HeartRates").
		Preload("StrideFreqs").
		First(&record)
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "查询成功",
		"data":    record,
	})
}

// SpeedSegment is a stretch of training during which the stride-frequency
// value stayed constant.
type SpeedSegment struct {
	// Duration is the segment length in seconds.
	Duration float64
	// Speed is the stride-frequency value held during the segment.
	Speed float64
}

// performLinearRegression fits y = intercept + slope*x over the given points
// using the regression library. Each map in averages contributes its
// key/value pairs as (x, y) points.
//
// It returns a result whose Equation is "无数据" for empty input and "计算失败"
// when the library reports an error; in both cases no coefficients are set.
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	r := new(regression.Regression)
	r.SetObserved("y")
	r.SetVar(0, "x")
	for _, m := range averages {
		for x, y := range m {
			r.Train(regression.DataPoint(y, []float64{x}))
		}
	}

	if err := r.Run(); err != nil {
		log.Printf("线性回归计算失败: %v", err)
		return models.RegressionResult{Equation: "计算失败"}
	}

	// Coeff(0) is the constant term, Coeff(1) the coefficient of variable 0.
	slope := r.Coeff(1)
	intercept := r.Coeff(0)
	r2 := r.R2
	return models.RegressionResult{
		RegressionType: models.LinearRegression,
		Slope:          &slope,
		Intercept:      &intercept,
		RSquared:       &r2,
		Equation:       r.Formula,
	}
}

// performLogarithmicRegression fits y = LogB + LogA*ln(x+1) over the given
// points using the regression library. The +1 shift keeps the logarithm
// defined at x = 0.
//
// It returns a result whose Equation is "无数据" for empty input and "计算失败"
// when the library reports an error; in both cases no coefficients are set.
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	r := new(regression.Regression)
	r.SetObserved("y")
	r.SetVar(0, "log(x+1)")
	for _, m := range averages {
		for speed, hr := range m {
			r.Train(regression.DataPoint(hr, []float64{math.Log(speed + 1)}))
		}
	}

	if err := r.Run(); err != nil {
		log.Printf("对数回归计算失败: %v", err)
		return models.RegressionResult{Equation: "计算失败"}
	}

	logA := r.Coeff(1)
	logB := r.Coeff(0)
	r2 := r.R2
	return models.RegressionResult{
		RegressionType: models.LogarithmicRegression,
		LogA:           &logA,
		LogB:           &logB,
		RSquared:       &r2,
		Equation:       r.Formula,
	}
}

// performQuadraticRegression fits y = a*x² + b*x + c by solving the
// least-squares normal equations with Cramer's rule.
//
// An earlier version of this function used the regression library with x and
// x² as variables; it was replaced by this closed-form solution, which the
// original author wrote to mirror the Flutter client's computation.
//
// It returns a result whose Equation is "无数据" for empty input and "无法拟合"
// when the system cannot be solved (fewer than three points, or a zero
// determinant); in both cases no coefficients are set.
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	// Step 1: flatten the input into parallel x/y slices.
	var xValues, yValues []float64
	for _, m := range averages {
		for speed, hr := range m {
			xValues = append(xValues, speed)
			yValues = append(yValues, hr)
		}
	}

	// A parabola needs at least three points. With fewer the normal matrix is
	// singular, but rounding can leave a tiny non-zero determinant that would
	// slip past the exact-zero check below and yield meaningless coefficients.
	if len(xValues) < 3 {
		return models.RegressionResult{Equation: "无法拟合"}
	}
	n := float64(len(xValues))

	// Step 2: accumulate the power sums needed by the normal equations.
	var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
	for i, x := range xValues {
		y := yValues[i]
		x2 := x * x
		x3 := x2 * x
		x4 := x3 * x
		sumX += x
		sumY += y
		sumX2 += x2
		sumX3 += x3
		sumX4 += x4
		sumXY += x * y
		sumX2Y += x2 * y
	}

	// Step 3: normal-equation matrix and right-hand side. The unknowns are
	// ordered (constant, linear, quadratic).
	matrix := [3][3]float64{
		{n, sumX, sumX2},
		{sumX, sumX2, sumX3},
		{sumX2, sumX3, sumX4},
	}
	vector := []float64{sumY, sumXY, sumX2Y}

	// Step 4: a zero determinant means there is no unique solution.
	det := det3x3(matrix)
	if det == 0 {
		return models.RegressionResult{Equation: "无法拟合"}
	}

	// Step 5: Cramer's rule. Replacing column k with the right-hand side
	// yields the k-th unknown: column 0 -> constant, 1 -> linear, 2 -> quadratic.
	coefC := det3x3([3][3]float64{
		{vector[0], matrix[0][1], matrix[0][2]},
		{vector[1], matrix[1][1], matrix[1][2]},
		{vector[2], matrix[2][1], matrix[2][2]},
	}) / det
	coefB := det3x3([3][3]float64{
		{matrix[0][0], vector[0], matrix[0][2]},
		{matrix[1][0], vector[1], matrix[1][2]},
		{matrix[2][0], vector[2], matrix[2][2]},
	}) / det
	coefA := det3x3([3][3]float64{
		{matrix[0][0], matrix[0][1], vector[0]},
		{matrix[1][0], matrix[1][1], vector[1]},
		{matrix[2][0], matrix[2][1], vector[2]},
	}) / det

	// Step 6: coefficient of determination; left at 0 when y has no variance.
	var ssRes, ssTot float64
	meanY := sumY / n
	for i, x := range xValues {
		y := yValues[i]
		yPred := coefA*x*x + coefB*x + coefC
		ssRes += math.Pow(y-yPred, 2)
		ssTot += math.Pow(y-meanY, 2)
	}
	rSquared := 0.0
	if ssTot != 0 {
		rSquared = 1 - ssRes/ssTot
	}

	// Step 7: human-readable formula.
	return models.RegressionResult{
		RegressionType: models.QuadraticRegression,
		QuadraticA:     &coefA,
		QuadraticB:     &coefB,
		QuadraticC:     &coefC,
		RSquared:       &rSquared,
		Equation:       formatEquation(coefA, coefB, coefC, rSquared),
	}
}

// det3x3 returns the determinant of m by cofactor expansion along the first
// row.
func det3x3(m [3][3]float64) float64 {
	return m[0][0]*(m[1][1]*m[2][2]-m[1][2]*m[2][1]) -
		m[0][1]*(m[1][0]*m[2][2]-m[1][2]*m[2][0]) +
		m[0][2]*(m[1][0]*m[2][1]-m[1][1]*m[2][0])
}

// formatEquation renders a quadratic as
// "y = A x² ± B x ± C (R² = R)" with every number fixed to four decimals.
// The sign of b and c is emitted as a separate " + " / " - " token followed
// by the magnitude.
func formatEquation(a, b, c, r2 float64) string {
	aStr := fmt.Sprintf("%.4f", a)
	bStr := fmt.Sprintf("%.4f", b)
	cStr := fmt.Sprintf("%.4f", c)
	r2Str := fmt.Sprintf("%.4f", r2)

	var sb strings.Builder
	sb.WriteString("y = ")

	// Quadratic term: a leading minus is attached directly to the number.
	if a >= 0 {
		sb.WriteString(aStr + " x²")
	} else {
		sb.WriteString("-" + strings.TrimPrefix(aStr, "-") + " x²")
	}

	// Linear term.
	if b >= 0 {
		sb.WriteString(" + " + bStr + " x")
	} else {
		sb.WriteString(" - " + strings.TrimPrefix(bStr, "-") + " x")
	}

	// Constant term.
	if c >= 0 {
		sb.WriteString(" + " + cStr)
	} else {
		sb.WriteString(" - " + strings.TrimPrefix(cStr, "-"))
	}

	sb.WriteString(" (R² = " + r2Str + ")")
	return sb.String()
}

// convertStrideFrequencyToSegments collapses stride-frequency samples into
// segments of constant value.
//
// Samples with a non-positive Value are dropped, the rest are ordered by Time,
// and a new segment starts whenever the Value changes. Segment durations are
// Time differences divided by 1000 (presumably milliseconds to seconds — the
// unit of Time is not visible here). Zero-length segments are discarded.
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
	if len(steps) == 0 {
		return []SpeedSegment{}
	}

	// Keep only samples that carry a positive value.
	validSteps := make([]models.StepStrideFreq, 0, len(steps))
	for _, s := range steps {
		if s.Value > 0 {
			validSteps = append(validSteps, s)
		}
	}
	if len(validSteps) == 0 {
		return []SpeedSegment{}
	}

	// Order chronologically; the stable sort keeps samples that share a
	// timestamp in their incoming order.
	sort.SliceStable(validSteps, func(i, j int) bool {
		return validSteps[i].Time < validSteps[j].Time
	})

	segments := make([]SpeedSegment, 0)
	startTime := validSteps[0].Time
	currentValue := validSteps[0].Value
	for i := 1; i < len(validSteps); i++ {
		if validSteps[i].Value == currentValue {
			continue
		}
		// The value changed: close the running segment at this sample's time.
		duration := float64(validSteps[i].Time-startTime) / 1000.0
		if duration > 0 {
			segments = append(segments, SpeedSegment{Duration: duration, Speed: float64(currentValue)})
		}
		startTime = validSteps[i].Time
		currentValue = validSteps[i].Value
	}

	// Close the final segment at the last sample.
	duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
	if duration > 0 {
		segments = append(segments, SpeedSegment{Duration: duration, Speed: float64(currentValue)})
	}
	return segments
}

// calculateSegmentAverages computes, for each sufficiently long segment, the
// mean heart rate inside a window that starts 60 s into the segment.
//
// Segments are laid end to end starting at t = 0. A segment shorter than
// 120-errorThreshold seconds is skipped. For the others the window is
// [start+60, start+120], shortened to end at start+120-errorThreshold when the
// segment is shorter than 120 s. Each result is a single-entry map from the
// segment's Speed to the mean heart rate; segments with no samples in the
// window produce no entry.
//
// NOTE(review): heart-rate times are compared against a clock that starts at
// 0 with the first segment, so hr.Time is presumably relative to the first
// valid stride sample — confirm against the data producer.
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
	tolerance := float64(errorThreshold)
	minRequired := 60 + (60 - tolerance)

	results := make([]map[float64]float64, 0)
	currentTime := 0.0
	for _, seg := range segments {
		// Advance the running clock up front so every path accounts for the
		// segment's duration.
		segStart := currentTime
		currentTime += seg.Duration

		if seg.Duration < minRequired {
			continue
		}

		startSec := segStart + 60
		endSec := segStart + 120
		if seg.Duration < 120 {
			endSec -= tolerance
		}

		// Average every heart-rate sample that falls inside the window.
		sum, count := 0, 0
		for _, hr := range heartRates {
			sec := float64(hr.Time) / 1000.0
			if sec >= startSec && sec <= endSec {
				sum += hr.Value
				count++
			}
		}
		if count > 0 {
			avg := float64(sum) / float64(count)
			results = append(results, map[float64]float64{seg.Speed: avg})
		}
	}
	return results
}

// CalculateSegmentAveragesByRealStep derives constant-value segments from the
// stride-frequency samples and returns the per-segment heart-rate averages,
// using an error threshold of 15 seconds.
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
	segments := convertStrideFrequencyToSegments(steps)
	return calculateSegmentAverages(heartRates, segments, 15)
}

// SaveRegressionResults stores the given regression results for trainId in a
// single transaction. TrainId is set on every element of results before it is
// written.
//
// NOTE(review): the upsert's conflict target is the id column. Results built
// by GetOrCalculateRegression do not set an ID, so this presumably always
// inserts new rows rather than updating the existing row per regression type.
// If a unique index on (train_id, regression type) exists, the conflict
// target should name those columns instead — confirm against the schema.
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
	return tc.DB.Transaction(func(tx *gorm.DB) error {
		for i := range results {
			res := &results[i]
			res.TrainId = trainId

			err := tx.Clauses(clause.OnConflict{
				Columns: []clause.Column{{Name: "id"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"equation":    res.Equation,
					"slope":       res.Slope,
					"intercept":   res.Intercept,
					"log_a":       res.LogA,
					"log_b":       res.LogB,
					"quadratic_a": res.QuadraticA,
					"quadratic_b": res.QuadraticB,
					"quadratic_c": res.QuadraticC,
					"r_squared":   res.RSquared,
					"updated_at":  gorm.Expr("CURRENT_TIMESTAMP"),
				}),
			}).Create(res).Error
			if err != nil {
				return err
			}
		}
		return nil
	})
}

// GetOrCalculateRegression returns the regression results for trainId.
//
// When at least three stored results exist they are returned as is. Otherwise
// the training record is loaded (heart rates with heart_rate_type = 1, stride
// frequencies with predict_value = 1), the linear, logarithmic and quadratic
// fits are computed in that order, saved, and returned.
//
// It returns gorm.ErrRecordNotFound (from First) when the training record
// does not exist, and an error when there is not enough data to fit.
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
	// Try the stored results first.
	var results []models.RegressionResult
	err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
	if err == nil && len(results) >= 3 {
		return results, nil
	}
	if err != nil {
		// Not fatal: fall through and recompute, but leave a trace.
		log.Printf("查询回归结果失败: %v", err)
	}

	// Load the training record with the samples used for fitting.
	var record models.StepTrainRecord
	if err := tc.DB.
		Where("train_id = ?", trainId).
		Preload("HeartRates", "heart_rate_type = ?", 1).
		Preload("StrideFreqs", "predict_value = ?", 1).
		First(&record).Error; err != nil {
		return nil, err
	}

	averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
	if len(averages) == 0 {
		return nil, errors.New("无足够数据进行回归计算")
	}

	// The regression type is set explicitly on every entry because the fit
	// helpers leave it unset on their failure paths.
	linearRes := performLinearRegression(averages)
	logRes := performLogarithmicRegression(averages)
	quadRes := performQuadraticRegression(averages)

	results = []models.RegressionResult{
		{
			RegressionType: models.LinearRegression,
			TrainId:        trainId,
			Equation:       linearRes.Equation,
			Slope:          linearRes.Slope,
			Intercept:      linearRes.Intercept,
			RSquared:       linearRes.RSquared,
		},
		{
			RegressionType: models.LogarithmicRegression,
			TrainId:        trainId,
			Equation:       logRes.Equation,
			LogA:           logRes.LogA,
			LogB:           logRes.LogB,
			RSquared:       logRes.RSquared,
		},
		{
			RegressionType: models.QuadraticRegression,
			TrainId:        trainId,
			Equation:       quadRes.Equation,
			QuadraticA:     quadRes.QuadraticA,
			QuadraticB:     quadRes.QuadraticB,
			QuadraticC:     quadRes.QuadraticC,
			RSquared:       quadRes.RSquared,
		},
	}

	if err := tc.SaveRegressionResults(trainId, results); err != nil {
		log.Printf("保存回归结果失败: %v", err)
		return nil, err
	}
	return results, nil
}

// GetRegressionResult returns the regression results for the training
// identified by the trainId path parameter, computing them on first access.
//
// Responses: 400 for an invalid ID, 404 when the training record does not
// exist, 500 on any other failure.
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
	tid, err := strconv.ParseUint(c.Param("trainId"), 10, 32)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
		return
	}

	result, err := tc.GetOrCalculateRegression(uint(tid))
	if err != nil {
		// A missing training record is a client error, not a server fault.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "获取成功",
		"data":    result,
	})
}

// GetTrainingRank ranks one training against all stored regression results of
// the requested type.
//
// The trainId path parameter selects the training; the integer query
// parameter "type" must equal models.LinearRegression or
// models.QuadraticRegression. The rank is 1 plus the number of rows whose
// ranking value beats this training's value, so ties share a rank.
//
// Responses: 400 for invalid parameters, 404 when the training or its ranking
// value does not exist, 500 on database errors.
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
	trainIdStr := c.Param("trainId")
	regressionTypeStr := c.Query("type")

	regressionType, err := strconv.Atoi(regressionTypeStr)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
		return
	}

	// Only the linear and quadratic fits have a ranking rule.
	regType := models.RegressionType(regressionType)
	if regType != models.LinearRegression && regType != models.QuadraticRegression {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
		return
	}

	tid, err := strconv.ParseUint(trainIdStr, 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
		return
	}
	trainId := uint(tid)

	// Make sure the regression results exist before ranking.
	if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
		return
	}

	valueColumn := getValueColumn(regType)
	typeCondition := getTypeCondition(regType)

	// Fetch this training's ranking value. The type condition restricts the
	// lookup to the row that actually carries the column for this regression
	// type; without it the query could pick one of the training's other rows,
	// where the column is NULL and cannot be scanned into a float64.
	var baseValue float64
	baseQuery := tc.DB.Model(&models.RegressionResult{}).
		Select(valueColumn).
		Where("train_id = ?", trainId).
		Where(typeCondition)
	if err := baseQuery.Row().Scan(&baseValue); err != nil {
		// Row().Scan reports a missing row as sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) || errors.Is(err, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
		return
	}

	var rank struct {
		BetterCount int64
		Total       int64
	}

	// Size of the ranked population.
	if err := tc.DB.Model(&models.RegressionResult{}).
		Where(typeCondition).
		Count(&rank.Total).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
		return
	}

	// Number of rows that beat this training. The condition is assembled from
	// the fixed strings returned by the helpers, never from request input.
	betterCondition := fmt.Sprintf("%s %s ?", valueColumn, getComparisonOperator(regType))
	if err := tc.DB.Model(&models.RegressionResult{}).
		Where(typeCondition).
		Where(betterCondition, baseValue).
		Count(&rank.BetterCount).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
		return
	}

	// Competition ranking: equal values share the same rank.
	currentRank := rank.BetterCount + 1

	c.JSON(http.StatusOK, gin.H{
		"message": "排名查询成功",
		"data": gin.H{
			"trainId": trainId,
			"type":    regressionType,
			"rank":    currentRank,
			"total":   rank.Total,
		},
	})
}

// getValueColumn returns the SQL expression whose value is ranked for the
// given regression type, or "" for a type that has no ranking rule.
func getValueColumn(regType models.RegressionType) string {
	switch regType {
	case models.LinearRegression:
		return "slope"
	case models.QuadraticRegression:
		// The quadratic fit is ranked by the magnitude of its leading term.
		return "ABS(quadratic_a)"
	}
	return ""
}

// getComparisonOperator returns the SQL operator that selects rows ranking
// ahead of a base value: "<" for linear (smaller slope is better), ">" for
// quadratic (larger magnitude is better), "" for any other type.
func getComparisonOperator(regType models.RegressionType) string {
	switch regType {
	case models.LinearRegression:
		return "<"
	case models.QuadraticRegression:
		return ">"
	}
	return ""
}

// getTypeCondition returns the SQL condition that selects the rows taking
// part in the ranking for the given regression type, or "" for a type that
// has no ranking rule. Quadratic rows take part only when quadratic_a is
// negative.
func getTypeCondition(regType models.RegressionType) string {
	switch regType {
	case models.LinearRegression:
		return "slope IS NOT NULL"
	case models.QuadraticRegression:
		return "quadratic_a IS NOT NULL AND quadratic_a < 0"
	}
	return ""
}

// compareFloatPtr reports whether a sorts before b. A nil pointer sorts after
// every non-nil value, and two nils are not ordered. For non-nil values the
// order is ascending or descending according to the ascending flag.
func compareFloatPtr(a, b *float64, ascending bool) bool {
	switch {
	case a == nil:
		// Covers both "a and b nil" and "only a nil".
		return false
	case b == nil:
		return true
	case ascending:
		return *a < *b
	default:
		return *a > *b
	}
}

// compareQuadraticA reports whether a sorts before b when ordering quadratic
// leading coefficients by descending magnitude. A nil pointer sorts after
// every non-nil value, and two nils are not ordered.
func compareQuadraticA(a, b *float64) bool {
	switch {
	case a == nil:
		return false
	case b == nil:
		return true
	default:
		return math.Abs(*a) > math.Abs(*b)
	}
}