feat: regression
This commit is contained in:
@ -2,14 +2,19 @@ package controllers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/sajari/regression"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"hr_receiver/config"
|
"hr_receiver/config"
|
||||||
"hr_receiver/models"
|
"hr_receiver/models"
|
||||||
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StepTrainingController struct {
|
type StepTrainingController struct {
|
||||||
@ -187,3 +192,615 @@ func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
|
|||||||
"data": record,
|
"data": record,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 定义结构体
|
||||||
|
type SpeedSegment struct {
|
||||||
|
Duration float64
|
||||||
|
Speed float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现线性回归算法
|
||||||
|
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无数据",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集数据点
|
||||||
|
var points []struct{ x, y float64 }
|
||||||
|
for _, m := range averages {
|
||||||
|
for x, y := range m {
|
||||||
|
points = append(points, struct{ x, y float64 }{x, y})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用回归库计算
|
||||||
|
r := new(regression.Regression)
|
||||||
|
r.SetObserved("y")
|
||||||
|
r.SetVar(0, "x")
|
||||||
|
|
||||||
|
for _, p := range points {
|
||||||
|
r.Train(regression.DataPoint(p.y, []float64{p.x}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Run(); err != nil {
|
||||||
|
log.Printf("线性回归计算失败: %v", err)
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "计算失败",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建结果
|
||||||
|
slope := r.Coeff(0)
|
||||||
|
intercept := r.Coeff(1)
|
||||||
|
r2 := r.R2
|
||||||
|
return models.RegressionResult{
|
||||||
|
RegressionType: models.LinearRegression,
|
||||||
|
Slope: &slope,
|
||||||
|
Intercept: &intercept,
|
||||||
|
RSquared: &r2,
|
||||||
|
Equation: r.Formula,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现对数和二次回归算法
|
||||||
|
// 对数回归算法
|
||||||
|
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无数据",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集数据点
|
||||||
|
r := new(regression.Regression)
|
||||||
|
r.SetObserved("y")
|
||||||
|
r.SetVar(0, "log(x+1)")
|
||||||
|
|
||||||
|
for _, m := range averages {
|
||||||
|
for speed, hr := range m {
|
||||||
|
logSpeed := math.Log(speed + 1)
|
||||||
|
r.Train(regression.DataPoint(hr, []float64{logSpeed}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Run(); err != nil {
|
||||||
|
log.Printf("对数回归计算失败: %v", err)
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "计算失败",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建结果
|
||||||
|
logA := r.Coeff(1)
|
||||||
|
logB := r.Coeff(0)
|
||||||
|
r2 := r.R2
|
||||||
|
return models.RegressionResult{
|
||||||
|
RegressionType: models.LogarithmicRegression,
|
||||||
|
LogA: &logA,
|
||||||
|
LogB: &logB,
|
||||||
|
RSquared: &r2,
|
||||||
|
Equation: r.Formula,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 二次回归算法
|
||||||
|
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
// if len(averages) == 0 {
|
||||||
|
// return models.RegressionResult{
|
||||||
|
// Equation: "无数据",
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 收集数据点
|
||||||
|
// r := new(regression.Regression)
|
||||||
|
// r.SetObserved("y")
|
||||||
|
// r.SetVar(0, "x")
|
||||||
|
// r.SetVar(1, "x²")
|
||||||
|
//
|
||||||
|
// for _, m := range averages {
|
||||||
|
// for speed, hr := range m {
|
||||||
|
// speedSq := math.Pow(speed, 2)
|
||||||
|
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if err := r.Run(); err != nil {
|
||||||
|
// log.Printf("二次回归计算失败: %v", err)
|
||||||
|
// return models.RegressionResult{
|
||||||
|
// Equation: "计算失败",
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 创建结果
|
||||||
|
// a := r.Coeff(2)
|
||||||
|
// b := r.Coeff(1)
|
||||||
|
// c := r.Coeff(0)
|
||||||
|
// r2 := r.R2
|
||||||
|
// return models.RegressionResult{
|
||||||
|
// RegressionType: models.QuadraticRegression,
|
||||||
|
// QuadraticA: &a,
|
||||||
|
// QuadraticB: &b,
|
||||||
|
// QuadraticC: &c,
|
||||||
|
// RSquared: &r2,
|
||||||
|
// Equation: r.Formula,
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无数据",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤1:收集所有数据点(与Flutter一致)
|
||||||
|
var xValues []float64
|
||||||
|
var yValues []float64
|
||||||
|
for _, m := range averages {
|
||||||
|
for speed, hr := range m {
|
||||||
|
xValues = append(xValues, speed)
|
||||||
|
yValues = append(yValues, hr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n := float64(len(xValues))
|
||||||
|
|
||||||
|
// 步骤2:计算各项和(完全匹配Flutter的计算)
|
||||||
|
var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
|
||||||
|
for i := 0; i < len(xValues); i++ {
|
||||||
|
x := xValues[i]
|
||||||
|
y := yValues[i]
|
||||||
|
x2 := x * x
|
||||||
|
x3 := x2 * x
|
||||||
|
x4 := x3 * x
|
||||||
|
|
||||||
|
sumX += x
|
||||||
|
sumY += y
|
||||||
|
sumX2 += x2
|
||||||
|
sumX3 += x3
|
||||||
|
sumX4 += x4
|
||||||
|
sumXY += x * y
|
||||||
|
sumX2Y += x2 * y
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤3:构建正规方程矩阵(与Flutter完全一致)
|
||||||
|
matrix := [3][3]float64{
|
||||||
|
{n, sumX, sumX2},
|
||||||
|
{sumX, sumX2, sumX3},
|
||||||
|
{sumX2, sumX3, sumX4},
|
||||||
|
}
|
||||||
|
vector := []float64{sumY, sumXY, sumX2Y}
|
||||||
|
|
||||||
|
// 步骤4:计算矩阵行列式(复制Flutter的determinant3x3逻辑)
|
||||||
|
det := matrix[0][0]*(matrix[1][1]*matrix[2][2]-matrix[1][2]*matrix[2][1]) -
|
||||||
|
matrix[0][1]*(matrix[1][0]*matrix[2][2]-matrix[1][2]*matrix[2][0]) +
|
||||||
|
matrix[0][2]*(matrix[1][0]*matrix[2][1]-matrix[1][1]*matrix[2][0])
|
||||||
|
|
||||||
|
if det == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无法拟合",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤5:克莱姆法则求解系数(顺序与Flutter一致)
|
||||||
|
// 注意:最终系数顺序 a=二次项, b=一次项, c=常数项
|
||||||
|
c := det3x3([3][3]float64{
|
||||||
|
{vector[0], matrix[0][1], matrix[0][2]},
|
||||||
|
{vector[1], matrix[1][1], matrix[1][2]},
|
||||||
|
{vector[2], matrix[2][1], matrix[2][2]},
|
||||||
|
}) / det
|
||||||
|
|
||||||
|
b := det3x3([3][3]float64{
|
||||||
|
{matrix[0][0], vector[0], matrix[0][2]},
|
||||||
|
{matrix[1][0], vector[1], matrix[1][2]},
|
||||||
|
{matrix[2][0], vector[2], matrix[2][2]},
|
||||||
|
}) / det
|
||||||
|
|
||||||
|
a := det3x3([3][3]float64{
|
||||||
|
{matrix[0][0], matrix[0][1], vector[0]},
|
||||||
|
{matrix[1][0], matrix[1][1], vector[1]},
|
||||||
|
{matrix[2][0], matrix[2][1], vector[2]},
|
||||||
|
}) / det
|
||||||
|
|
||||||
|
// 步骤6:计算R平方(完全复制Flutter的计算逻辑)
|
||||||
|
var ssRes, ssTot float64
|
||||||
|
meanY := sumY / n
|
||||||
|
for i := 0; i < len(xValues); i++ {
|
||||||
|
x := xValues[i]
|
||||||
|
y := yValues[i]
|
||||||
|
yPred := a*x*x + b*x + c
|
||||||
|
ssRes += math.Pow(y-yPred, 2)
|
||||||
|
ssTot += math.Pow(y-meanY, 2)
|
||||||
|
}
|
||||||
|
rSquared := 0.0
|
||||||
|
if ssTot != 0 {
|
||||||
|
rSquared = 1 - ssRes/ssTot
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤7:格式化公式字符串(与Flutter格式完全一致)
|
||||||
|
equation := formatEquation(a, b, c, rSquared)
|
||||||
|
|
||||||
|
return models.RegressionResult{
|
||||||
|
RegressionType: models.QuadraticRegression,
|
||||||
|
QuadraticA: &a,
|
||||||
|
QuadraticB: &b,
|
||||||
|
QuadraticC: &c,
|
||||||
|
RSquared: &rSquared,
|
||||||
|
Equation: equation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3x3行列式计算(与Flutter实现相同)
|
||||||
|
func det3x3(m [3][3]float64) float64 {
|
||||||
|
return m[0][0]*(m[1][1]*m[2][2]-m[1][2]*m[2][1]) -
|
||||||
|
m[0][1]*(m[1][0]*m[2][2]-m[1][2]*m[2][0]) +
|
||||||
|
m[0][2]*(m[1][0]*m[2][1]-m[1][1]*m[2][0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 公式格式化(完全匹配Flutter格式)
|
||||||
|
func formatEquation(a, b, c, r2 float64) string {
|
||||||
|
// 保留4位小数
|
||||||
|
aStr := fmt.Sprintf("%.4f", a)
|
||||||
|
bStr := fmt.Sprintf("%.4f", b)
|
||||||
|
cStr := fmt.Sprintf("%.4f", c)
|
||||||
|
r2Str := fmt.Sprintf("%.4f", r2)
|
||||||
|
|
||||||
|
builder := strings.Builder{}
|
||||||
|
builder.WriteString("y = ")
|
||||||
|
|
||||||
|
// 处理二次项
|
||||||
|
if a >= 0 {
|
||||||
|
builder.WriteString(aStr + " x²")
|
||||||
|
} else {
|
||||||
|
builder.WriteString("-" + strings.TrimPrefix(aStr, "-") + " x²")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理一次项
|
||||||
|
if b >= 0 {
|
||||||
|
builder.WriteString(" + " + bStr + " x")
|
||||||
|
} else {
|
||||||
|
builder.WriteString(" - " + strings.TrimPrefix(bStr, "-") + " x")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理常数项
|
||||||
|
if c >= 0 {
|
||||||
|
builder.WriteString(" + " + cStr)
|
||||||
|
} else {
|
||||||
|
builder.WriteString(" - " + strings.TrimPrefix(cStr, "-"))
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteString(" (R² = " + r2Str + ")")
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步频数据转换为速度段
|
||||||
|
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
|
||||||
|
if len(steps) == 0 {
|
||||||
|
return []SpeedSegment{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 过滤零值并排序
|
||||||
|
validSteps := make([]models.StepStrideFreq, 0, len(steps))
|
||||||
|
for _, s := range steps {
|
||||||
|
if s.Value > 0 {
|
||||||
|
validSteps = append(validSteps, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validSteps) == 0 {
|
||||||
|
return []SpeedSegment{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按时间排序
|
||||||
|
for i := 0; i < len(validSteps)-1; i++ {
|
||||||
|
for j := i + 1; j < len(validSteps); j++ {
|
||||||
|
if validSteps[i].Time > validSteps[j].Time {
|
||||||
|
validSteps[i], validSteps[j] = validSteps[j], validSteps[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建速度段
|
||||||
|
segments := make([]SpeedSegment, 0)
|
||||||
|
startTime := validSteps[0].Time
|
||||||
|
currentValue := validSteps[0].Value
|
||||||
|
|
||||||
|
for i := 1; i < len(validSteps); i++ {
|
||||||
|
if validSteps[i].Value != currentValue {
|
||||||
|
duration := float64(validSteps[i].Time-startTime) / 1000.0
|
||||||
|
if duration > 0 {
|
||||||
|
segments = append(segments, SpeedSegment{
|
||||||
|
Duration: duration,
|
||||||
|
Speed: float64(currentValue),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
startTime = validSteps[i].Time
|
||||||
|
currentValue = validSteps[i].Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加最后一个段
|
||||||
|
if len(validSteps) > 0 {
|
||||||
|
duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
|
||||||
|
if duration > 0 {
|
||||||
|
segments = append(segments, SpeedSegment{
|
||||||
|
Duration: duration,
|
||||||
|
Speed: float64(currentValue),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return segments
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算区段平均值
|
||||||
|
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
|
||||||
|
currentTime := 0.0
|
||||||
|
results := make([]map[float64]float64, 0)
|
||||||
|
|
||||||
|
for _, seg := range segments {
|
||||||
|
minRequired := 60 + (60 - float64(errorThreshold))
|
||||||
|
|
||||||
|
// 跳过不满足条件的区段
|
||||||
|
if seg.Duration < minRequired {
|
||||||
|
currentTime += seg.Duration
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算时间窗口
|
||||||
|
startSec := currentTime + 60
|
||||||
|
endSec := currentTime
|
||||||
|
if seg.Duration >= 120 {
|
||||||
|
endSec = currentTime + 120
|
||||||
|
} else {
|
||||||
|
endSec = currentTime + 120 - float64(errorThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集该区段的心率数据
|
||||||
|
sum, count := 0, 0
|
||||||
|
for _, hr := range heartRates {
|
||||||
|
sec := float64(hr.Time) / 1000.0
|
||||||
|
if sec >= startSec && sec <= endSec {
|
||||||
|
sum += hr.Value
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算平均值
|
||||||
|
if count > 0 {
|
||||||
|
avg := float64(sum) / float64(count)
|
||||||
|
results = append(results, map[float64]float64{seg.Speed: avg})
|
||||||
|
}
|
||||||
|
|
||||||
|
currentTime += seg.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算步频区段的心率平均值
|
||||||
|
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
|
||||||
|
segments := convertStrideFrequencyToSegments(steps)
|
||||||
|
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
|
||||||
|
}
|
||||||
|
|
||||||
|
// 存储回归结果到数据库
|
||||||
|
func (tc *StepTrainingController) SaveRegressionResult(trainId uint, result models.RegressionResult) error {
|
||||||
|
result.TrainId = trainId
|
||||||
|
|
||||||
|
return tc.DB.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "id"}},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||||
|
"equation": result.Equation,
|
||||||
|
"slope": result.Slope,
|
||||||
|
"intercept": result.Intercept,
|
||||||
|
"log_a": result.LogA,
|
||||||
|
"log_b": result.LogB,
|
||||||
|
"quadratic_a": result.QuadraticA,
|
||||||
|
"quadratic_b": result.QuadraticB,
|
||||||
|
"quadratic_c": result.QuadraticC,
|
||||||
|
"r_squared": result.RSquared,
|
||||||
|
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}),
|
||||||
|
}).Create(&result).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取或计算回归结果
|
||||||
|
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) (models.RegressionResult, error) {
|
||||||
|
// 首先尝试从数据库获取
|
||||||
|
var result models.RegressionResult
|
||||||
|
err := tc.DB.Where("train_id = ?", trainId).First(&result).Error
|
||||||
|
|
||||||
|
// 如果找到记录,直接返回
|
||||||
|
if err == nil {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果错误不是记录不存在,返回错误
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return models.RegressionResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询训练记录及相关数据
|
||||||
|
var record models.StepTrainRecord
|
||||||
|
if err := tc.DB.
|
||||||
|
Where("train_id = ?", uint(trainId)).
|
||||||
|
Preload("HeartRates", "heart_rate_type = ?", 1).
|
||||||
|
Preload("StrideFreqs", "predict_value = ?", 1).
|
||||||
|
First(&record).Error; err != nil {
|
||||||
|
return models.RegressionResult{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算心率平均值(模仿Flutter的calculateSegmentAveragesByRealStep)
|
||||||
|
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{}, errors.New("无足够数据进行回归计算")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算三种回归
|
||||||
|
result = models.RegressionResult{
|
||||||
|
TrainId: trainId,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 线性回归
|
||||||
|
linearRes := performLinearRegression(averages)
|
||||||
|
result.Slope = linearRes.Slope
|
||||||
|
result.Intercept = linearRes.Intercept
|
||||||
|
result.RSquared = linearRes.RSquared
|
||||||
|
result.Equation = "线性回归: " + linearRes.Equation
|
||||||
|
|
||||||
|
// 对数回归
|
||||||
|
logRes := performLogarithmicRegression(averages)
|
||||||
|
result.LogA = logRes.LogA
|
||||||
|
result.LogB = logRes.LogB
|
||||||
|
if result.Equation != "" {
|
||||||
|
result.Equation += "\n"
|
||||||
|
}
|
||||||
|
result.Equation += "对数回归: " + logRes.Equation
|
||||||
|
|
||||||
|
// 二次回归
|
||||||
|
quadRes := performQuadraticRegression(averages)
|
||||||
|
result.QuadraticA = quadRes.QuadraticA
|
||||||
|
result.QuadraticB = quadRes.QuadraticB
|
||||||
|
result.QuadraticC = quadRes.QuadraticC
|
||||||
|
if result.Equation != "" {
|
||||||
|
result.Equation += "\n"
|
||||||
|
}
|
||||||
|
result.Equation += "二次回归: " + quadRes.Equation
|
||||||
|
|
||||||
|
// 保存计算结果到数据库
|
||||||
|
if err := tc.SaveRegressionResult(trainId, result); err != nil {
|
||||||
|
log.Printf("保存回归结果失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 新增接口:获取回归结果
|
||||||
|
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
|
||||||
|
trainIdStr := c.Param("trainId")
|
||||||
|
tid, err := strconv.ParseUint(trainIdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tc.GetOrCalculateRegression(uint(tid))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "获取成功",
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取训练记录的排名
|
||||||
|
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
||||||
|
// 解析参数
|
||||||
|
trainIdStr := c.Param("trainId")
|
||||||
|
|
||||||
|
regressionTypeStr := c.Query("type")
|
||||||
|
regressionType, err := strconv.Atoi(regressionTypeStr) // 字符串转整型
|
||||||
|
if err != nil {
|
||||||
|
// 转换失败时返回400错误
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
regType := models.RegressionType(regressionType)
|
||||||
|
|
||||||
|
// 验证回归类型
|
||||||
|
if regType != models.LinearRegression && regType != models.QuadraticRegression {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换trainId
|
||||||
|
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trainId := uint(tid)
|
||||||
|
|
||||||
|
// 确保回归结果存在
|
||||||
|
if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取所有记录用于排名
|
||||||
|
var records []models.RegressionResult
|
||||||
|
query := tc.DB.Model(&models.RegressionResult{})
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
query = query.Where("slope IS NOT NULL").Select("train_id, slope AS metric")
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
query = query.Where("quadratic_a IS NOT NULL").
|
||||||
|
Select("train_id, ABS(quadratic_a) AS metric")
|
||||||
|
}
|
||||||
|
if err := query.Find(&records).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理无数据情况
|
||||||
|
if len(records) == 0 {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 排序逻辑
|
||||||
|
sort.Slice(records, func(i, j int) bool {
|
||||||
|
if regType == models.LinearRegression {
|
||||||
|
return *records[i].Slope < *records[j].Slope
|
||||||
|
}
|
||||||
|
return math.Abs(*records[i].QuadraticA) > math.Abs(*records[j].QuadraticA) // 二次回归按绝对值降序
|
||||||
|
})
|
||||||
|
|
||||||
|
// 计算排名(处理并列)
|
||||||
|
currentRank := 1
|
||||||
|
rankMap := make(map[uint]int)
|
||||||
|
for i, record := range records {
|
||||||
|
// 处理第一个记录
|
||||||
|
if i == 0 {
|
||||||
|
rankMap[record.TrainId] = currentRank
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if regType == models.LinearRegression {
|
||||||
|
if records[i].Slope != records[i-1].Slope {
|
||||||
|
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
|
||||||
|
}
|
||||||
|
rankMap[record.TrainId] = currentRank
|
||||||
|
} else {
|
||||||
|
if records[i].QuadraticA != records[i-1].QuadraticA {
|
||||||
|
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
|
||||||
|
}
|
||||||
|
rankMap[record.TrainId] = currentRank
|
||||||
|
}
|
||||||
|
// 检测值是否变化
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前训练记录的排名
|
||||||
|
rank, exists := rankMap[trainId]
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回响应
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "排名查询成功",
|
||||||
|
"data": gin.H{
|
||||||
|
"trainId": trainId,
|
||||||
|
"type": regressionType,
|
||||||
|
"rank": rank,
|
||||||
|
"total": len(records),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
1
go.mod
1
go.mod
@ -5,6 +5,7 @@ go 1.23.3
|
|||||||
require (
|
require (
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||||
|
github.com/sajari/regression v1.0.1
|
||||||
github.com/spf13/viper v1.20.0
|
github.com/spf13/viper v1.20.0
|
||||||
gonum.org/v1/gonum v0.16.0
|
gonum.org/v1/gonum v0.16.0
|
||||||
gorm.io/driver/postgres v1.5.11
|
gorm.io/driver/postgres v1.5.11
|
||||||
|
|||||||
2
go.sum
2
go.sum
@ -75,6 +75,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
|||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||||
|
github.com/sajari/regression v1.0.1 h1:iTVc6ZACGCkoXC+8NdqH5tIreslDTT/bXxT6OmHR5PE=
|
||||||
|
github.com/sajari/regression v1.0.1/go.mod h1:NeG/XTW1lYfGY7YV/Z0nYDV/RGh3wxwd1yW46835flM=
|
||||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||||
|
|||||||
6
main.go
6
main.go
@ -22,9 +22,11 @@ func main() {
|
|||||||
&models.BeltAnalysis{},
|
&models.BeltAnalysis{},
|
||||||
&models.StepTrainRecord{},
|
&models.StepTrainRecord{},
|
||||||
&models.StepHeartRate{},
|
&models.StepHeartRate{},
|
||||||
&models.StepStrideFreq{})
|
&models.StepStrideFreq{},
|
||||||
|
&models.RegressionResult{},
|
||||||
|
)
|
||||||
|
|
||||||
// 启动服务
|
// 启动服务
|
||||||
r := routes.SetupRouter()
|
r := routes.SetupRouter()
|
||||||
r.Run(":8080")
|
r.Run(":8000")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -36,3 +36,33 @@ type StepTrainRecord struct {
|
|||||||
HeartRates []StepHeartRate `gorm:"foreignKey:TrainId;references:TrainId" json:"heartRates"`
|
HeartRates []StepHeartRate `gorm:"foreignKey:TrainId;references:TrainId" json:"heartRates"`
|
||||||
StrideFreqs []StepStrideFreq `gorm:"foreignKey:TrainId;references:TrainId" json:"strideFreqs"`
|
StrideFreqs []StepStrideFreq `gorm:"foreignKey:TrainId;references:TrainId" json:"strideFreqs"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RegressionType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
LinearRegression RegressionType = iota + 1
|
||||||
|
LogarithmicRegression
|
||||||
|
QuadraticRegression
|
||||||
|
)
|
||||||
|
|
||||||
|
type RegressionResult struct {
|
||||||
|
gorm.Model
|
||||||
|
RegressionType RegressionType `gorm:"column:regression_type;index" json:"regressionType"` // 训练记录ID
|
||||||
|
TrainId uint `gorm:"column:train_id;index" json:"trainId"` // 训练记录ID
|
||||||
|
Equation string `gorm:"type:text" json:"equation"` // 回归方程
|
||||||
|
|
||||||
|
// 线性回归系数
|
||||||
|
Slope *float64 `gorm:"column:slope" json:"slope"`
|
||||||
|
Intercept *float64 `gorm:"column:intercept" json:"intercept"`
|
||||||
|
|
||||||
|
// 对数回归系数
|
||||||
|
LogA *float64 `gorm:"column:log_a" json:"logA"`
|
||||||
|
LogB *float64 `gorm:"column:log_b" json:"logB"`
|
||||||
|
|
||||||
|
// 二次回归系数
|
||||||
|
QuadraticA *float64 `gorm:"column:quadratic_a" json:"quadraticA"`
|
||||||
|
QuadraticB *float64 `gorm:"column:quadratic_b" json:"quadraticB"`
|
||||||
|
QuadraticC *float64 `gorm:"column:quadratic_c" json:"quadraticC"`
|
||||||
|
|
||||||
|
RSquared *float64 `gorm:"column:r_squared" json:"rSquared"` // R平方值
|
||||||
|
}
|
||||||
|
|||||||
@ -27,6 +27,7 @@ func SetupRouter() *gin.Engine {
|
|||||||
steps.POST("", stepTrainController.CreateTrainingRecord)
|
steps.POST("", stepTrainController.CreateTrainingRecord)
|
||||||
steps.GET("train-records", stepTrainController.GetTrainingRecords)
|
steps.GET("train-records", stepTrainController.GetTrainingRecords)
|
||||||
steps.GET("train-data/:trainId", stepTrainController.GetTrainingRecordByTrainId)
|
steps.GET("train-data/:trainId", stepTrainController.GetTrainingRecordByTrainId)
|
||||||
|
steps.GET("train-rank/:trainId", stepTrainController.GetTrainingRank)
|
||||||
// 可扩展其他路由:GET, PUT, DELETE等
|
// 可扩展其他路由:GET, PUT, DELETE等
|
||||||
}
|
}
|
||||||
auth := v1.Group("/auth")
|
auth := v1.Group("/auth")
|
||||||
|
|||||||
Reference in New Issue
Block a user