diff --git a/controllers/step_train.go b/controllers/step_train.go index 6438ed8..0a44421 100644 --- a/controllers/step_train.go +++ b/controllers/step_train.go @@ -2,14 +2,19 @@ package controllers import ( "errors" + "fmt" "github.com/gin-gonic/gin" + "github.com/sajari/regression" "gorm.io/gorm" "gorm.io/gorm/clause" "hr_receiver/config" "hr_receiver/models" + "log" "math" "net/http" + "sort" "strconv" + "strings" ) type StepTrainingController struct { @@ -187,3 +192,615 @@ func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) { "data": record, }) } + +// 定义结构体 +type SpeedSegment struct { + Duration float64 + Speed float64 +} + +// 实现线性回归算法 +func performLinearRegression(averages []map[float64]float64) models.RegressionResult { + if len(averages) == 0 { + return models.RegressionResult{ + Equation: "无数据", + } + } + + // 收集数据点 + var points []struct{ x, y float64 } + for _, m := range averages { + for x, y := range m { + points = append(points, struct{ x, y float64 }{x, y}) + } + } + + // 使用回归库计算 + r := new(regression.Regression) + r.SetObserved("y") + r.SetVar(0, "x") + + for _, p := range points { + r.Train(regression.DataPoint(p.y, []float64{p.x})) + } + + if err := r.Run(); err != nil { + log.Printf("线性回归计算失败: %v", err) + return models.RegressionResult{ + Equation: "计算失败", + } + } + + // 创建结果 + slope := r.Coeff(0) + intercept := r.Coeff(1) + r2 := r.R2 + return models.RegressionResult{ + RegressionType: models.LinearRegression, + Slope: &slope, + Intercept: &intercept, + RSquared: &r2, + Equation: r.Formula, + } +} + +// 实现对数和二次回归算法 +// 对数回归算法 +func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult { + if len(averages) == 0 { + return models.RegressionResult{ + Equation: "无数据", + } + } + + // 收集数据点 + r := new(regression.Regression) + r.SetObserved("y") + r.SetVar(0, "log(x+1)") + + for _, m := range averages { + for speed, hr := range m { + logSpeed := math.Log(speed + 1) + r.Train(regression.DataPoint(hr, []float64{logSpeed})) + } + } + + if err := r.Run(); err != nil { + log.Printf("对数回归计算失败: %v", err) + return models.RegressionResult{ + Equation: "计算失败", + } + } + + // 创建结果 + logA := r.Coeff(1) + logB := r.Coeff(0) + r2 := r.R2 + return models.RegressionResult{ + RegressionType: models.LogarithmicRegression, + LogA: &logA, + LogB: &logB, + RSquared: &r2, + Equation: r.Formula, + } +} + +// 二次回归算法 +//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult { +// if len(averages) == 0 { +// return models.RegressionResult{ +// Equation: "无数据", +// } +// } +// +// // 收集数据点 +// r := new(regression.Regression) +// r.SetObserved("y") +// r.SetVar(0, "x") +// r.SetVar(1, "x²") +// +// for _, m := range averages { +// for speed, hr := range m { +// speedSq := math.Pow(speed, 2) +// r.Train(regression.DataPoint(hr, []float64{speed, speedSq})) +// } +// } +// +// if err := r.Run(); err != nil { +// log.Printf("二次回归计算失败: %v", err) +// return models.RegressionResult{ +// Equation: "计算失败", +// } +// } +// +// // 创建结果 +// a := r.Coeff(2) +// b := r.Coeff(1) +// c := r.Coeff(0) +// r2 := r.R2 +// return models.RegressionResult{ +// RegressionType: models.QuadraticRegression, +// QuadraticA: &a, +// QuadraticB: &b, +// QuadraticC: &c, +// RSquared: &r2, +// Equation: r.Formula, +// } +//} + +func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult { + if len(averages) == 0 { + return models.RegressionResult{ + Equation: "无数据", + } + } + + // 步骤1:收集所有数据点(与Flutter一致) + var xValues []float64 + var yValues []float64 + for _, m := range averages { + for speed, hr := range m { + xValues = append(xValues, speed) + yValues = append(yValues, hr) + } + } + n := float64(len(xValues)) + + // 步骤2:计算各项和(完全匹配Flutter的计算) + var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64 + for i := 0; i < len(xValues); i++ { + x := xValues[i] + y := yValues[i] + x2 := x * x + x3 := x2 * x + x4 := x3 * x + + sumX += x + sumY += y + sumX2 += x2 + sumX3 += x3 + sumX4 += x4 + sumXY += x * y + sumX2Y += x2 * y + } + + // 步骤3:构建正规方程矩阵(与Flutter完全一致) + matrix := [3][3]float64{ + {n, sumX, sumX2}, + {sumX, sumX2, sumX3}, + {sumX2, sumX3, sumX4}, + } + vector := []float64{sumY, sumXY, sumX2Y} + + // 步骤4:计算矩阵行列式(复制Flutter的determinant3x3逻辑) + det := matrix[0][0]*(matrix[1][1]*matrix[2][2]-matrix[1][2]*matrix[2][1]) - + matrix[0][1]*(matrix[1][0]*matrix[2][2]-matrix[1][2]*matrix[2][0]) + + matrix[0][2]*(matrix[1][0]*matrix[2][1]-matrix[1][1]*matrix[2][0]) + + if det == 0 { + return models.RegressionResult{ + Equation: "无法拟合", + } + } + + // 步骤5:克莱姆法则求解系数(顺序与Flutter一致) + // 注意:最终系数顺序 a=二次项, b=一次项, c=常数项 + c := det3x3([3][3]float64{ + {vector[0], matrix[0][1], matrix[0][2]}, + {vector[1], matrix[1][1], matrix[1][2]}, + {vector[2], matrix[2][1], matrix[2][2]}, + }) / det + + b := det3x3([3][3]float64{ + {matrix[0][0], vector[0], matrix[0][2]}, + {matrix[1][0], vector[1], matrix[1][2]}, + {matrix[2][0], vector[2], matrix[2][2]}, + }) / det + + a := det3x3([3][3]float64{ + {matrix[0][0], matrix[0][1], vector[0]}, + {matrix[1][0], matrix[1][1], vector[1]}, + {matrix[2][0], matrix[2][1], vector[2]}, + }) / det + + // 步骤6:计算R平方(完全复制Flutter的计算逻辑) + var ssRes, ssTot float64 + meanY := sumY / n + for i := 0; i < len(xValues); i++ { + x := xValues[i] + y := yValues[i] + yPred := a*x*x + b*x + c + ssRes += math.Pow(y-yPred, 2) + ssTot += math.Pow(y-meanY, 2) + } + rSquared := 0.0 + if ssTot != 0 { + rSquared = 1 - ssRes/ssTot + } + + // 步骤7:格式化公式字符串(与Flutter格式完全一致) + equation := formatEquation(a, b, c, rSquared) + + return models.RegressionResult{ + RegressionType: models.QuadraticRegression, + QuadraticA: &a, + QuadraticB: &b, + QuadraticC: &c, + RSquared: &rSquared, + Equation: equation, + } +} + +// 3x3行列式计算(与Flutter实现相同) +func det3x3(m [3][3]float64) float64 { + return m[0][0]*(m[1][1]*m[2][2]-m[1][2]*m[2][1]) - + m[0][1]*(m[1][0]*m[2][2]-m[1][2]*m[2][0]) + + m[0][2]*(m[1][0]*m[2][1]-m[1][1]*m[2][0]) +} + +// 公式格式化(完全匹配Flutter格式) +func formatEquation(a, b, c, r2 float64) string { + // 保留4位小数 + aStr := fmt.Sprintf("%.4f", a) + bStr := fmt.Sprintf("%.4f", b) + cStr := fmt.Sprintf("%.4f", c) + r2Str := fmt.Sprintf("%.4f", r2) + + builder := strings.Builder{} + builder.WriteString("y = ") + + // 处理二次项 + if a >= 0 { + builder.WriteString(aStr + " x²") + } else { + builder.WriteString("-" + strings.TrimPrefix(aStr, "-") + " x²") + } + + // 处理一次项 + if b >= 0 { + builder.WriteString(" + " + bStr + " x") + } else { + builder.WriteString(" - " + strings.TrimPrefix(bStr, "-") + " x") + } + + // 处理常数项 + if c >= 0 { + builder.WriteString(" + " + cStr) + } else { + builder.WriteString(" - " + strings.TrimPrefix(cStr, "-")) + } + + builder.WriteString(" (R² = " + r2Str + ")") + return builder.String() +} + +// 步频数据转换为速度段 +func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment { + if len(steps) == 0 { + return []SpeedSegment{} + } + + // 过滤零值并排序 + validSteps := make([]models.StepStrideFreq, 0, len(steps)) + for _, s := range steps { + if s.Value > 0 { + validSteps = append(validSteps, s) + } + } + + if len(validSteps) == 0 { + return []SpeedSegment{} + } + + // 按时间排序 + for i := 0; i < len(validSteps)-1; i++ { + for j := i + 1; j < len(validSteps); j++ { + if validSteps[i].Time > validSteps[j].Time { + validSteps[i], validSteps[j] = validSteps[j], validSteps[i] + } + } + } + + // 创建速度段 + segments := make([]SpeedSegment, 0) + startTime := validSteps[0].Time + currentValue := validSteps[0].Value + + for i := 1; i < len(validSteps); i++ { + if validSteps[i].Value != currentValue { + duration := float64(validSteps[i].Time-startTime) / 1000.0 + if duration > 0 { + segments = append(segments, SpeedSegment{ + Duration: duration, + Speed: float64(currentValue), + }) + } + startTime = validSteps[i].Time + currentValue = validSteps[i].Value + } + } + + // 添加最后一个段 + if len(validSteps) > 0 { + duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0 + if duration > 0 { + segments = append(segments, SpeedSegment{ + Duration: duration, + Speed: float64(currentValue), + }) + } + } + + return segments +} + +// 计算区段平均值 +func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 { + currentTime := 0.0 + results := make([]map[float64]float64, 0) + + for _, seg := range segments { + minRequired := 60 + (60 - float64(errorThreshold)) + + // 跳过不满足条件的区段 + if seg.Duration < minRequired { + currentTime += seg.Duration + continue + } + + // 计算时间窗口 + startSec := currentTime + 60 + endSec := currentTime + if seg.Duration >= 120 { + endSec = currentTime + 120 + } else { + endSec = currentTime + 120 - float64(errorThreshold) + } + + // 收集该区段的心率数据 + sum, count := 0, 0 + for _, hr := range heartRates { + sec := float64(hr.Time) / 1000.0 + if sec >= startSec && sec <= endSec { + sum += hr.Value + count++ + } + } + + // 计算平均值 + if count > 0 { + avg := float64(sum) / float64(count) + results = append(results, map[float64]float64{seg.Speed: avg}) + } + + currentTime += seg.Duration + } + + return results +} + +// 计算步频区段的心率平均值 +func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 { + segments := convertStrideFrequencyToSegments(steps) + return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值 +} + +// 存储回归结果到数据库 +func (tc *StepTrainingController) SaveRegressionResult(trainId uint, result models.RegressionResult) error { + result.TrainId = trainId + + return tc.DB.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "id"}}, + DoUpdates: clause.Assignments(map[string]interface{}{ + "equation": result.Equation, + "slope": result.Slope, + "intercept": result.Intercept, + "log_a": result.LogA, + "log_b": result.LogB, + "quadratic_a": result.QuadraticA, + "quadratic_b": result.QuadraticB, + "quadratic_c": result.QuadraticC, + "r_squared": result.RSquared, + "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), + }), + }).Create(&result).Error +} + +// 获取或计算回归结果 +func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models.RegressionResult, error) { + // 首先尝试从数据库获取 + var result models.RegressionResult + err := tc.DB.Where("train_id = ?", trainId).First(&result).Error + + // 如果找到记录,直接返回 + if err == nil { + return result, nil + } + + // 如果错误不是记录不存在,返回错误 + if !errors.Is(err, gorm.ErrRecordNotFound) { + return models.RegressionResult{}, err + } + + // 查询训练记录及相关数据 + var record models.StepTrainRecord + if err := tc.DB. + Where("train_id = ?", uint(trainId)). + Preload("HeartRates", "heart_rate_type = ?", 1). + Preload("StrideFreqs", "predict_value = ?", 1). + First(&record).Error; err != nil { + return models.RegressionResult{}, err + } + + // 计算心率平均值(模仿Flutter的calculateSegmentAveragesByRealStep) + averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs) + if len(averages) == 0 { + return models.RegressionResult{}, errors.New("无足够数据进行回归计算") + } + + // 计算三种回归 + result = models.RegressionResult{ + TrainId: trainId, + } + + // 线性回归 + linearRes := performLinearRegression(averages) + result.Slope = linearRes.Slope + result.Intercept = linearRes.Intercept + result.RSquared = linearRes.RSquared + result.Equation = "线性回归: " + linearRes.Equation + + // 对数回归 + logRes := performLogarithmicRegression(averages) + result.LogA = logRes.LogA + result.LogB = logRes.LogB + if result.Equation != "" { + result.Equation += "\n" + } + result.Equation += "对数回归: " + logRes.Equation + + // 二次回归 + quadRes := performQuadraticRegression(averages) + result.QuadraticA = quadRes.QuadraticA + result.QuadraticB = quadRes.QuadraticB + result.QuadraticC = quadRes.QuadraticC + if result.Equation != "" { + result.Equation += "\n" + } + result.Equation += "二次回归: " + quadRes.Equation + + // 保存计算结果到数据库 + if err := tc.SaveRegressionResult(trainId, result); err != nil { + log.Printf("保存回归结果失败: %v", err) + } + + return result, nil +} + +// 新增接口:获取回归结果 +func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) { + trainIdStr := c.Param("trainId") + tid, err := strconv.ParseUint(trainIdStr, 10, 32) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"}) + return + } + + result, err := tc.GetOrCalculateRegression(uint(tid)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "获取成功", + "data": result, + }) +} + +// 获取训练记录的排名 +func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) { + // 解析参数 + trainIdStr := c.Param("trainId") + + regressionTypeStr := c.Query("type") + regressionType, err := strconv.Atoi(regressionTypeStr) // 字符串转整型 + if err != nil { + // 转换失败时返回400错误 + c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"}) + return + } + + regType := models.RegressionType(regressionType) + + // 验证回归类型 + if regType != models.LinearRegression && regType != models.QuadraticRegression { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"}) + return + } + + // 转换trainId + tid, err := strconv.ParseUint(trainIdStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"}) + return + } + trainId := uint(tid) + + // 确保回归结果存在 + if _, err := tc.GetOrCalculateRegression(trainId); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()}) + return + } + + // 获取所有记录用于排名 + var records []models.RegressionResult + query := tc.DB.Model(&models.RegressionResult{}) + switch regType { + case models.LinearRegression: + query = query.Where("slope IS NOT NULL").Select("train_id, slope AS metric") + case models.QuadraticRegression: + query = query.Where("quadratic_a IS NOT NULL"). + Select("train_id, ABS(quadratic_a) AS metric") + } + if err := query.Find(&records).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"}) + return + } + + // 处理无数据情况 + if len(records) == 0 { + c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"}) + return + } + + // 排序逻辑 + sort.Slice(records, func(i, j int) bool { + if regType == models.LinearRegression { + return *records[i].Slope < *records[j].Slope + } + return math.Abs(*records[i].QuadraticA) > math.Abs(*records[j].QuadraticA) // 二次回归按绝对值降序 + }) + + // 计算排名(处理并列) + currentRank := 1 + rankMap := make(map[uint]int) + for i, record := range records { + // 处理第一个记录 + if i == 0 { + rankMap[record.TrainId] = currentRank + continue + } + + if regType == models.LinearRegression { + if records[i].Slope != records[i-1].Slope { + currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1 + } + rankMap[record.TrainId] = currentRank + } else { + if records[i].QuadraticA != records[i-1].QuadraticA { + currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1 + } + rankMap[record.TrainId] = currentRank + } + // 检测值是否变化 + } + + // 获取当前训练记录的排名 + rank, exists := rankMap[trainId] + if !exists { + c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"}) + return + } + + // 返回响应 + c.JSON(http.StatusOK, gin.H{ + "message": "排名查询成功", + "data": gin.H{ + "trainId": trainId, + "type": regressionType, + "rank": rank, + "total": len(records), + }, + }) +} diff --git a/go.mod b/go.mod index a545bb8..7180d14 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.3 require ( github.com/gin-gonic/gin v1.10.0 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/sajari/regression v1.0.1 github.com/spf13/viper v1.20.0 gonum.org/v1/gonum v0.16.0 gorm.io/driver/postgres v1.5.11 diff --git a/go.sum b/go.sum index 22c8773..7b91dc4 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo= github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k= +github.com/sajari/regression v1.0.1 h1:iTVc6ZACGCkoXC+8NdqH5tIreslDTT/bXxT6OmHR5PE= +github.com/sajari/regression v1.0.1/go.mod h1:NeG/XTW1lYfGY7YV/Z0nYDV/RGh3wxwd1yW46835flM= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= diff --git a/main.go b/main.go index 7b0cd02..4c86365 100644 --- a/main.go +++ b/main.go @@ -22,9 +22,11 @@ func main() { &models.BeltAnalysis{}, &models.StepTrainRecord{}, &models.StepHeartRate{}, - &models.StepStrideFreq{}) + &models.StepStrideFreq{}, + &models.RegressionResult{}, + ) // 启动服务 r := routes.SetupRouter() - r.Run(":8080") + r.Run(":8000") } diff --git a/models/step_train.go b/models/step_train.go index c163e0a..4262acf 100644 --- a/models/step_train.go +++ b/models/step_train.go @@ -36,3 +36,33 @@ type StepTrainRecord struct { HeartRates []StepHeartRate `gorm:"foreignKey:TrainId;references:TrainId" json:"heartRates"` StrideFreqs []StepStrideFreq `gorm:"foreignKey:TrainId;references:TrainId" json:"strideFreqs"` } + +type RegressionType int + +const ( + LinearRegression RegressionType = iota + 1 + LogarithmicRegression + QuadraticRegression +) + +type RegressionResult struct { + gorm.Model + RegressionType RegressionType `gorm:"column:regression_type;index" json:"regressionType"` // 训练记录ID + TrainId uint `gorm:"column:train_id;index" json:"trainId"` // 训练记录ID + Equation string `gorm:"type:text" json:"equation"` // 回归方程 + + // 线性回归系数 + Slope *float64 `gorm:"column:slope" json:"slope"` + Intercept *float64 `gorm:"column:intercept" json:"intercept"` + + // 对数回归系数 + LogA *float64 `gorm:"column:log_a" json:"logA"` + LogB *float64 `gorm:"column:log_b" json:"logB"` + + // 二次回归系数 + QuadraticA *float64 `gorm:"column:quadratic_a" json:"quadraticA"` + QuadraticB *float64 `gorm:"column:quadratic_b" json:"quadraticB"` + QuadraticC *float64 `gorm:"column:quadratic_c" json:"quadraticC"` + + RSquared *float64 `gorm:"column:r_squared" json:"rSquared"` // R平方值 +} diff --git a/routes/routes.go b/routes/routes.go index 0e2370d..f128378 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -27,6 +27,7 @@ func SetupRouter() *gin.Engine { steps.POST("", stepTrainController.CreateTrainingRecord) steps.GET("train-records", stepTrainController.GetTrainingRecords) steps.GET("train-data/:trainId", stepTrainController.GetTrainingRecordByTrainId) + steps.GET("train-rank/:trainId", stepTrainController.GetTrainingRank) // 可扩展其他路由:GET, PUT, DELETE等 } auth := v1.Group("/auth")