Files
hr_data_analyzer/controllers/step_train.go
2025-08-05 09:55:37 +08:00

830 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controllers
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/sajari/regression"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"hr_receiver/config"
"hr_receiver/models"
"log"
"math"
"net/http"
"sort"
"strconv"
"strings"
)
// StepTrainingController groups the HTTP handlers and persistence helpers for
// step-training records. Every method reaches the database through DB.
type StepTrainingController struct {
// DB is the gorm database handle used by all methods of the controller.
DB *gorm.DB
}
// NewStepTrainingController builds a controller whose DB field is the
// package-level handle config.DB.
func NewStepTrainingController() *StepTrainingController {
	controller := new(StepTrainingController)
	controller.DB = config.DB
	return controller
}
// CreateTrainingRecord receives a training record as JSON and stores it.
//
// The main record and its heart-rate and stride-frequency samples are written
// inside one transaction. Each write is an upsert: the main record conflicts
// on "train_id", the samples conflict on "identifier". A bind failure answers
// 400, a database failure answers 500, success answers 201 with the train ID.
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.StepTrainRecord

	// Bind and validate the JSON payload.
	if bindErr := c.ShouldBindJSON(&record); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	// persist performs all writes; returning an error rolls the transaction back.
	persist := func(tx *gorm.DB) error {
		// Main record: the associations are omitted here and written below.
		recordUpsert := clause.OnConflict{
			Columns: []clause.Column{{Name: "train_id"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"max_heart_rate": record.MaxHeartRate,
				"start_time":     record.StartTime,
				"end_time":       record.EndTime,
				"duration":       record.Duration,
				"dead_zone":      record.DeadZone,
				"name":           record.Name,
				"evaluation":     record.Evaluation,
			}),
		}
		if err := tx.Clauses(recordUpsert).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
			return err
		}

		// Heart-rate samples, one upsert per sample.
		for i := range record.HeartRates {
			sample := &record.HeartRates[i]
			sampleUpsert := clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"heart_rate_type": sample.HeartRateType,
					"value":           sample.Value,
					"time":            sample.Time,
				}),
			}
			if err := tx.Clauses(sampleUpsert).Create(sample).Error; err != nil {
				return err
			}
		}

		// Stride-frequency samples, one upsert per sample.
		for i := range record.StrideFreqs {
			sample := &record.StrideFreqs[i]
			sampleUpsert := clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"value": sample.Value,
					"time":  sample.Time,
				}),
			}
			if err := tx.Clauses(sampleUpsert).Create(sample).Error; err != nil {
				return err
			}
		}
		return nil
	}

	if txErr := tc.DB.Transaction(persist); txErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": txErr.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}
// GetTrainingRecords returns one page of training records, newest start time
// first, together with pagination metadata.
//
// Query parameters: pageNum (default 1) and pageSize (default 10). A pageNum
// below 1 becomes 1; a pageSize outside 1..100 falls back to 10.
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
	var page struct {
		PageNum  int `form:"pageNum,default=1"`
		PageSize int `form:"pageSize,default=10"`
	}
	if bindErr := c.ShouldBindQuery(&page); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	// Normalise out-of-range paging values instead of rejecting them.
	if page.PageNum < 1 {
		page.PageNum = 1
	}
	if page.PageSize < 1 || page.PageSize > 100 {
		page.PageSize = 10
	}

	// Total row count, needed for the page count below.
	var totalRows int64
	if countErr := tc.DB.Model(&models.StepTrainRecord{}).Count(&totalRows).Error; countErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
		return
	}

	// Fetch the requested page ordered by start time, descending.
	var records []models.StepTrainRecord
	skip := (page.PageNum - 1) * page.PageSize
	findErr := tc.DB.
		Order("start_time DESC").
		Offset(skip).
		Limit(page.PageSize).
		Find(&records).Error
	if findErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": findErr.Error()})
		return
	}

	pagination := gin.H{
		"currentPage": page.PageNum,
		"pageSize":    page.PageSize,
		"totalPage":   int(math.Ceil(float64(totalRows) / float64(page.PageSize))),
		"totalList":   totalRows,
	}
	c.JSON(http.StatusOK, gin.H{
		"message": "查询成功",
		"data": gin.H{
			"list":       records,
			"pagination": pagination,
		},
	})
}
// GetTrainingRecordByTrainId returns a single training record, including its
// heart-rate and stride-frequency samples, looked up by the "trainId" path
// parameter.
//
// Responses: 400 when the parameter is empty or not an unsigned integer, 404
// when no record matches, 500 on any other database error, 200 with the
// record otherwise.
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
	trainId := c.Param("trainId")
	if trainId == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
		return
	}

	// Parse as unsigned: the value is used as a uint key, so a negative input
	// must be rejected here rather than wrapping around to a huge ID.
	tid, err := strconv.ParseUint(trainId, 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
		return
	}

	var record models.StepTrainRecord
	// Load the main record and preload both sample associations.
	result := tc.DB.Where("train_id = ?", uint(tid)).
		Preload("HeartRates").
		Preload("StrideFreqs").
		First(&record)
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
		}
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "查询成功",
		"data":    record,
	})
}
// SpeedSegment describes one run of consecutive stride-frequency samples that
// share the same value, as produced by convertStrideFrequencyToSegments.
type SpeedSegment struct {
// Duration is the length of the segment, computed as a sample-time
// difference divided by 1000 (presumably milliseconds to seconds — confirm
// against the producer of the samples).
Duration float64
// Speed is the stride-frequency value held throughout the segment.
Speed float64
}
// performLinearRegression fits y = slope*x + intercept over every (x, y) pair
// found in averages (each map key is x, its value is y).
//
// It returns a result whose Equation is "无数据" when averages is empty and
// "计算失败" when the regression library reports an error; in both cases the
// coefficient fields are left nil. On success Slope, Intercept and RSquared
// are set and Equation holds the library's formula string.
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{
			Equation: "无数据",
		}
	}

	r := new(regression.Regression)
	r.SetObserved("y")
	r.SetVar(0, "x")

	// Feed every (x, y) pair straight into the model.
	for _, m := range averages {
		for x, y := range m {
			r.Train(regression.DataPoint(y, []float64{x}))
		}
	}

	if err := r.Run(); err != nil {
		log.Printf("线性回归计算失败: %v", err)
		return models.RegressionResult{
			Equation: "计算失败",
		}
	}

	// In sajari/regression, coefficient 0 is the offset (intercept) and
	// coefficient 1 belongs to the first variable (the slope). The previous
	// code had the two swapped.
	intercept := r.Coeff(0)
	slope := r.Coeff(1)
	r2 := r.R2
	return models.RegressionResult{
		RegressionType: models.LinearRegression,
		Slope:          &slope,
		Intercept:      &intercept,
		RSquared:       &r2,
		Equation:       r.Formula,
	}
}
// performLogarithmicRegression fits y = LogA*ln(x+1) + LogB over every
// (x, y) pair in averages by running a linear regression on ln(x+1).
//
// Equation is "无数据" for empty input and "计算失败" when the regression
// library fails; otherwise LogA, LogB, RSquared and the library's formula
// string are returned.
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	model := new(regression.Regression)
	model.SetObserved("y")
	model.SetVar(0, "log(x+1)")

	// Train on the transformed predictor ln(speed + 1).
	for _, entry := range averages {
		for speed, hr := range entry {
			transformed := math.Log(speed + 1)
			model.Train(regression.DataPoint(hr, []float64{transformed}))
		}
	}

	runErr := model.Run()
	if runErr != nil {
		log.Printf("对数回归计算失败: %v", runErr)
		return models.RegressionResult{Equation: "计算失败"}
	}

	// Coefficient 1 multiplies the log term; coefficient 0 is the constant.
	var (
		logA     = model.Coeff(1)
		logB     = model.Coeff(0)
		rSquared = model.R2
	)
	return models.RegressionResult{
		RegressionType: models.LogarithmicRegression,
		LogA:           &logA,
		LogB:           &logB,
		RSquared:       &rSquared,
		Equation:       model.Formula,
	}
}
// 二次回归算法
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
// if len(averages) == 0 {
// return models.RegressionResult{
// Equation: "无数据",
// }
// }
//
// // 收集数据点
// r := new(regression.Regression)
// r.SetObserved("y")
// r.SetVar(0, "x")
// r.SetVar(1, "x²")
//
// for _, m := range averages {
// for speed, hr := range m {
// speedSq := math.Pow(speed, 2)
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
// }
// }
//
// if err := r.Run(); err != nil {
// log.Printf("二次回归计算失败: %v", err)
// return models.RegressionResult{
// Equation: "计算失败",
// }
// }
//
// // 创建结果
// a := r.Coeff(2)
// b := r.Coeff(1)
// c := r.Coeff(0)
// r2 := r.R2
// return models.RegressionResult{
// RegressionType: models.QuadraticRegression,
// QuadraticA: &a,
// QuadraticB: &b,
// QuadraticC: &c,
// RSquared: &r2,
// Equation: r.Formula,
// }
//}
// performQuadraticRegression fits y = a*x² + b*x + c by least squares over
// every (x, y) pair in averages, solving the 3x3 normal equations with
// Cramer's rule. Per the original comments, the computation mirrors the
// Flutter client step by step.
//
// Equation is "无数据" for empty input and "无法拟合" when the normal matrix
// has a zero determinant. Otherwise QuadraticA/B/C, RSquared and a formatted
// equation string are returned; RSquared is 0 when all y values are equal.
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	// Step 1: flatten the maps into parallel x/y slices.
	var xs, ys []float64
	for _, entry := range averages {
		for speed, hr := range entry {
			xs = append(xs, speed)
			ys = append(ys, hr)
		}
	}
	n := float64(len(xs))

	// Step 2: accumulate the power sums needed by the normal equations.
	var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
	for i, x := range xs {
		y := ys[i]
		x2 := x * x
		x3 := x2 * x
		x4 := x3 * x
		sumX += x
		sumY += y
		sumX2 += x2
		sumX3 += x3
		sumX4 += x4
		sumXY += x * y
		sumX2Y += x2 * y
	}

	// Step 3: normal-equation matrix and right-hand side.
	normal := [3][3]float64{
		{n, sumX, sumX2},
		{sumX, sumX2, sumX3},
		{sumX2, sumX3, sumX4},
	}
	rhs := [3]float64{sumY, sumXY, sumX2Y}

	// Step 4: a zero determinant means the system cannot be solved.
	det := det3x3(normal)
	if det == 0 {
		return models.RegressionResult{Equation: "无法拟合"}
	}

	// Step 5: Cramer's rule — replace one column with the right-hand side.
	// Column 0 yields the constant term, column 1 the linear term and
	// column 2 the quadratic term.
	solve := func(col int) float64 {
		substituted := normal
		for row := range substituted {
			substituted[row][col] = rhs[row]
		}
		return det3x3(substituted) / det
	}
	c := solve(0)
	b := solve(1)
	a := solve(2)

	// Step 6: coefficient of determination.
	meanY := sumY / n
	var ssRes, ssTot float64
	for i, x := range xs {
		y := ys[i]
		yPred := a*x*x + b*x + c
		ssRes += math.Pow(y-yPred, 2)
		ssTot += math.Pow(y-meanY, 2)
	}
	rSquared := 0.0
	if ssTot != 0 {
		rSquared = 1 - ssRes/ssTot
	}

	// Step 7: human-readable equation string.
	return models.RegressionResult{
		RegressionType: models.QuadraticRegression,
		QuadraticA:     &a,
		QuadraticB:     &b,
		QuadraticC:     &c,
		RSquared:       &rSquared,
		Equation:       formatEquation(a, b, c, rSquared),
	}
}
// det3x3 returns the determinant of a 3x3 matrix by cofactor expansion along
// the first row.
func det3x3(m [3][3]float64) float64 {
	top, mid, bot := m[0], m[1], m[2]

	// 2x2 minors obtained by deleting the first row and column 0, 1, 2.
	minor0 := mid[1]*bot[2] - mid[2]*bot[1]
	minor1 := mid[0]*bot[2] - mid[2]*bot[0]
	minor2 := mid[0]*bot[1] - mid[1]*bot[0]

	return top[0]*minor0 - top[1]*minor1 + top[2]*minor2
}
// formatEquation renders a quadratic fit as
// "y = A x² ± B x ± C (R² = R)" with every number printed to four decimals.
// The sign of b and c is emitted as a separate " + " / " - " joiner followed
// by the magnitude; per the original comment, the layout matches the Flutter
// client.
func formatEquation(a, b, c, r2 float64) string {
	const layout = "%.4f"

	// magnitude formats v and strips a leading minus sign, if any.
	magnitude := func(v float64) string {
		return strings.TrimPrefix(fmt.Sprintf(layout, v), "-")
	}

	// joined renders a non-leading term with an explicit sign joiner.
	joined := func(v float64, suffix string) string {
		if v >= 0 {
			return " + " + fmt.Sprintf(layout, v) + suffix
		}
		return " - " + magnitude(v) + suffix
	}

	// The leading term carries its sign directly, without surrounding spaces.
	var lead string
	if a >= 0 {
		lead = fmt.Sprintf(layout, a)
	} else {
		lead = "-" + magnitude(a)
	}

	return "y = " + lead + " x²" +
		joined(b, " x") +
		joined(c, "") +
		" (R² = " + fmt.Sprintf(layout, r2) + ")"
}
// convertStrideFrequencyToSegments collapses stride-frequency samples into
// segments of constant value.
//
// Samples with a non-positive Value are dropped, the rest are ordered by Time,
// and each run of equal consecutive values becomes one SpeedSegment. A
// segment's Duration is the Time difference between the first sample of the
// run and the first sample of the next run (or the last sample overall),
// divided by 1000; runs with a non-positive duration are skipped. The result
// is always a non-nil slice, empty when no usable samples exist.
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
	// Keep only samples with a positive value.
	validSteps := make([]models.StepStrideFreq, 0, len(steps))
	for _, s := range steps {
		if s.Value > 0 {
			validSteps = append(validSteps, s)
		}
	}
	if len(validSteps) == 0 {
		return []SpeedSegment{}
	}

	// Order by time. The stable sort keeps samples with equal timestamps in
	// input order and replaces the previous O(n²) exchange sort.
	sort.SliceStable(validSteps, func(i, j int) bool {
		return validSteps[i].Time < validSteps[j].Time
	})

	segments := make([]SpeedSegment, 0)
	startTime := validSteps[0].Time
	currentValue := validSteps[0].Value

	// closeSegment ends the current run at the sample with index end and
	// records it when its duration is positive.
	closeSegment := func(end int) {
		duration := float64(validSteps[end].Time-startTime) / 1000.0
		if duration > 0 {
			segments = append(segments, SpeedSegment{
				Duration: duration,
				Speed:    float64(currentValue),
			})
		}
	}

	for i := 1; i < len(validSteps); i++ {
		if validSteps[i].Value == currentValue {
			continue
		}
		// The value changed: finish the running segment and start a new one.
		closeSegment(i)
		startTime = validSteps[i].Time
		currentValue = validSteps[i].Value
	}

	// The final run ends at the last sample.
	closeSegment(len(validSteps) - 1)
	return segments
}
// calculateSegmentAverages computes, for each sufficiently long segment, the
// mean heart rate inside a fixed window of that segment.
//
// Segments are laid end to end on a timeline starting at 0. A segment shorter
// than 120-errorThreshold is skipped. For the others the window starts 60
// after the segment start and ends at 120 after the start, or at
// 120-errorThreshold when the segment is shorter than 120. Heart-rate sample
// times are divided by 1000 before comparison (presumably milliseconds to
// seconds — confirm against the data source). Each result is a single-entry
// map from the segment's Speed to the average; segments with no samples in
// the window contribute nothing.
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
	tolerance := float64(errorThreshold)
	minDuration := 60 + (60 - tolerance)

	averages := make([]map[float64]float64, 0)
	elapsed := 0.0
	for _, seg := range segments {
		// Remember where this segment starts, then advance the timeline.
		segStart := elapsed
		elapsed += seg.Duration

		if seg.Duration < minDuration {
			continue
		}

		// Window bounds relative to the segment start.
		windowStart := segStart + 60
		windowEnd := segStart + 120 - tolerance
		if seg.Duration >= 120 {
			windowEnd = segStart + 120
		}

		// Sum the heart-rate samples that fall inside the window.
		total, samples := 0, 0
		for _, hr := range heartRates {
			at := float64(hr.Time) / 1000.0
			if at < windowStart || at > windowEnd {
				continue
			}
			total += hr.Value
			samples++
		}
		if samples == 0 {
			continue
		}

		mean := float64(total) / float64(samples)
		averages = append(averages, map[float64]float64{seg.Speed: mean})
	}
	return averages
}
// CalculateSegmentAveragesByRealStep derives constant-value segments from the
// stride-frequency samples and returns the average heart rate for each
// qualifying segment, using an error threshold of 15.
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
	// NOTE(review): the previous comment described this as a 5-second default,
	// but the value passed is 15 — confirm which one is intended.
	const errorThreshold = 15

	return calculateSegmentAverages(heartRates, convertStrideFrequencyToSegments(steps), errorThreshold)
}
// SaveRegressionResults stores the given regression results for trainId in a
// single transaction. Each element gets its TrainId set to trainId and is
// written with an upsert whose conflict target is the "id" column; on
// conflict the equation, all coefficient columns, r_squared and updated_at
// are overwritten. The first failing write aborts the transaction and its
// error is returned.
//
// NOTE(review): the original comment said a composite unique constraint keeps
// one row per regression type, but the conflict target here is "id". Results
// built without an ID would therefore not hit this conflict path — verify
// against the models.RegressionResult schema.
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
	// upsertOne writes a single result inside the running transaction.
	upsertOne := func(tx *gorm.DB, res *models.RegressionResult) error {
		res.TrainId = trainId

		refreshed := map[string]interface{}{
			"equation":    res.Equation,
			"slope":       res.Slope,
			"intercept":   res.Intercept,
			"log_a":       res.LogA,
			"log_b":       res.LogB,
			"quadratic_a": res.QuadraticA,
			"quadratic_b": res.QuadraticB,
			"quadratic_c": res.QuadraticC,
			"r_squared":   res.RSquared,
			"updated_at":  gorm.Expr("CURRENT_TIMESTAMP"),
		}
		onConflict := clause.OnConflict{
			Columns:   []clause.Column{{Name: "id"}},
			DoUpdates: clause.Assignments(refreshed),
		}
		return tx.Clauses(onConflict).Create(res).Error
	}

	return tc.DB.Transaction(func(tx *gorm.DB) error {
		for i := range results {
			if err := upsertOne(tx, &results[i]); err != nil {
				return err
			}
		}
		return nil
	})
}
// GetOrCalculateRegression returns the regression results for a training
// session, computing and storing them when fewer than three are on record.
//
// When the stored lookup succeeds with at least three rows they are returned
// as is. Otherwise the training record is loaded with heart rates filtered by
// heart_rate_type = 1 and stride frequencies filtered by predict_value = 1,
// segment averages are computed, and linear, logarithmic and quadratic fits
// are produced (in that order), saved and returned. An error is returned when
// the record cannot be loaded, when no segment averages are available, or
// when saving fails.
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
	// Fast path: reuse stored results when all three types are present.
	var stored []models.RegressionResult
	lookupErr := tc.DB.Where("train_id = ?", trainId).Find(&stored).Error
	if lookupErr == nil && len(stored) >= 3 {
		return stored, nil
	}

	// Load the training record with the filtered sample associations.
	var record models.StepTrainRecord
	loadErr := tc.DB.
		Where("train_id = ?", trainId).
		Preload("HeartRates", "heart_rate_type = ?", 1).
		Preload("StrideFreqs", "predict_value = ?", 1).
		First(&record).Error
	if loadErr != nil {
		return nil, loadErr
	}

	averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
	if len(averages) == 0 {
		return nil, errors.New("无足够数据进行回归计算")
	}

	// Run the three fits, then copy only the fields relevant to each type.
	linear := performLinearRegression(averages)
	logarithmic := performLogarithmicRegression(averages)
	quadratic := performQuadraticRegression(averages)

	computed := []models.RegressionResult{
		{
			RegressionType: models.LinearRegression,
			TrainId:        trainId,
			Equation:       linear.Equation,
			Slope:          linear.Slope,
			Intercept:      linear.Intercept,
			RSquared:       linear.RSquared,
		},
		{
			RegressionType: models.LogarithmicRegression,
			TrainId:        trainId,
			Equation:       logarithmic.Equation,
			LogA:           logarithmic.LogA,
			LogB:           logarithmic.LogB,
			RSquared:       logarithmic.RSquared,
		},
		{
			RegressionType: models.QuadraticRegression,
			TrainId:        trainId,
			Equation:       quadratic.Equation,
			QuadraticA:     quadratic.QuadraticA,
			QuadraticB:     quadratic.QuadraticB,
			QuadraticC:     quadratic.QuadraticC,
			RSquared:       quadratic.RSquared,
		},
	}

	// Persist all three results in one call.
	if saveErr := tc.SaveRegressionResults(trainId, computed); saveErr != nil {
		log.Printf("保存回归结果失败: %v", saveErr)
		return nil, saveErr
	}
	return computed, nil
}
// GetRegressionResult returns the regression results for the training session
// named by the "trainId" path parameter, computing them on demand.
//
// Responses: 400 when the ID is not an unsigned integer, 404 when the
// training record does not exist, 500 for any other failure, 200 with the
// result list otherwise.
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
	// Parse with a 64-bit range, matching the other handlers in this file;
	// the previous 32-bit limit rejected IDs the sibling endpoints accept.
	tid, err := strconv.ParseUint(c.Param("trainId"), 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
		return
	}

	results, err := tc.GetOrCalculateRegression(uint(tid))
	if err != nil {
		// A missing training record is the client's problem, not a server
		// fault; report it as 404 like GetTrainingRecordByTrainId does.
		status := http.StatusInternalServerError
		if errors.Is(err, gorm.ErrRecordNotFound) {
			status = http.StatusNotFound
		}
		c.JSON(status, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "获取成功",
		"data":    results,
	})
}
// GetTrainingRank reports where one training session ranks among all stored
// regression results of the requested type.
//
// Inputs: path parameter "trainId" (unsigned integer) and query parameter
// "type" (integer value of models.RegressionType; only the linear and
// quadratic types are accepted).
//
// Ranking metric: slope for linear, |quadratic_a| for quadratic. Larger values
// rank first and equal values share a rank, with the next distinct value
// taking its position in the list (1, 1, 3, ...).
//
// Responses: 400 for a malformed type or ID, 500 when the regression results
// cannot be obtained or queried or when the session is missing from the
// ranking, 404 when there is no data to rank, 200 with trainId, type, rank
// and total otherwise.
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
	// The "type" query parameter must be an integer regression type.
	regressionType, err := strconv.Atoi(c.Query("type"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
		return
	}
	regType := models.RegressionType(regressionType)
	if regType != models.LinearRegression && regType != models.QuadraticRegression {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
		return
	}

	tid, err := strconv.ParseUint(c.Param("trainId"), 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
		return
	}
	trainId := uint(tid)

	// Make sure this session has regression results before ranking.
	if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
		return
	}

	// Pick the rows and the ranking metric for the requested type. The metric
	// returns the value to rank by and whether the row actually has one.
	query := tc.DB.Model(&models.RegressionResult{})
	var metric func(r *models.RegressionResult) (float64, bool)
	switch regType {
	case models.LinearRegression:
		query = query.Where("slope IS NOT NULL").Select("train_id, slope,regression_type")
		metric = func(r *models.RegressionResult) (float64, bool) {
			if r.Slope == nil {
				return 0, false
			}
			return *r.Slope, true
		}
	case models.QuadraticRegression:
		// Select the raw column so it scans into QuadraticA; the absolute
		// value is taken here instead of in an unaliased SQL expression that
		// never reached the struct.
		query = query.Where("quadratic_a IS NOT NULL").Select("train_id, quadratic_a,regression_type")
		metric = func(r *models.RegressionResult) (float64, bool) {
			if r.QuadraticA == nil {
				return 0, false
			}
			return math.Abs(*r.QuadraticA), true
		}
	}

	var records []models.RegressionResult
	if err := query.Find(&records).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
		return
	}
	if len(records) == 0 {
		c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
		return
	}

	// Sort descending by the metric; rows without a value go last.
	sort.SliceStable(records, func(i, j int) bool {
		vi, okI := metric(&records[i])
		vj, okJ := metric(&records[j])
		if !okI || !okJ {
			return okI && !okJ
		}
		return vi > vj
	})

	// Assign ranks, comparing metric values (not pointers) so that ties are
	// actually detected.
	rankMap := make(map[uint]int, len(records))
	currentRank := 1
	for i := range records {
		if i > 0 {
			cur, curOK := metric(&records[i])
			prev, prevOK := metric(&records[i-1])
			if curOK != prevOK || cur != prev {
				currentRank = i + 1
			}
		}
		rankMap[records[i].TrainId] = currentRank
	}

	rank, exists := rankMap[trainId]
	if !exists {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "排名查询成功",
		"data": gin.H{
			"trainId": trainId,
			"type":    regressionType,
			"rank":    rank,
			"total":   len(records),
		},
	})
}