Compare commits
21 Commits
243361ee80
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 1a252a12be | |||
| 8d8dd26a2c | |||
| a14c553736 | |||
| 563fdd8a7e | |||
| d3e87fda67 | |||
| 3078c13e14 | |||
| 7a2a44e327 | |||
| b00767dfcd | |||
| 2aa22b1385 | |||
| 07963218cd | |||
| 0852f4bc23 | |||
| 19148d7d35 | |||
| 4e10359a5b | |||
| 8510d19baa | |||
| f4df1724b2 | |||
| 5cca15808d | |||
| d69aaa5704 | |||
| 6e8232ff7f | |||
| 7b9e870bf9 | |||
| 22dc82e052 | |||
| 4bb2dff5ea |
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
.idea
|
.idea
|
||||||
hr_receiver.iml
|
hr_receiver.iml
|
||||||
main.go.bak
|
main.go.bak
|
||||||
|
config.yaml
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
database:
|
database:
|
||||||
host: localhost
|
host: localhost #when use docker change to "db"
|
||||||
port: 5432
|
port: 5432
|
||||||
user: postgres
|
user: postgres
|
||||||
password: root
|
password: root
|
||||||
895
controllers/step_train.go
Normal file
895
controllers/step_train.go
Normal file
@ -0,0 +1,895 @@
|
|||||||
|
package controllers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/sajari/regression"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
"hr_receiver/config"
|
||||||
|
"hr_receiver/models"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StepTrainingController struct {
|
||||||
|
DB *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStepTrainingController() *StepTrainingController {
|
||||||
|
return &StepTrainingController{DB: config.DB}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 接收训练记录
|
||||||
|
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
|
||||||
|
var record models.StepTrainRecord
|
||||||
|
|
||||||
|
// 绑定并验证JSON数据
|
||||||
|
if err := c.ShouldBindJSON(&record); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用事务保存数据[4](@ref)
|
||||||
|
err := tc.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 保存主记录
|
||||||
|
if err := tx.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "train_id"}}, // 指定冲突的列
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||||
|
"max_heart_rate": record.MaxHeartRate,
|
||||||
|
"start_time": record.StartTime,
|
||||||
|
"end_time": record.EndTime,
|
||||||
|
"duration": record.Duration,
|
||||||
|
"dead_zone": record.DeadZone,
|
||||||
|
"name": record.Name,
|
||||||
|
"evaluation": record.Evaluation,
|
||||||
|
}),
|
||||||
|
}).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存关联的心率数据
|
||||||
|
for i := range record.HeartRates {
|
||||||
|
if err := tx.Clauses(
|
||||||
|
clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{"heart_rate_type": record.HeartRates[i].HeartRateType, "value": record.HeartRates[i].Value, "time": record.HeartRates[i].Time}),
|
||||||
|
},
|
||||||
|
).Create(&record.HeartRates[i]).Error; err != nil {
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := range record.StrideFreqs {
|
||||||
|
if err := tx.Clauses(
|
||||||
|
clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{"value": record.StrideFreqs[i].Value, "time": record.StrideFreqs[i].Time}),
|
||||||
|
},
|
||||||
|
).Create(&record.StrideFreqs[i]).Error; err != nil {
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// ====== 新增部分:启动异步回归计算 ======
|
||||||
|
go func() {
|
||||||
|
// 查询完整数据(需要关联的心率和步频数据)
|
||||||
|
var fullRecord models.StepTrainRecord
|
||||||
|
if err := tc.DB.
|
||||||
|
Where("train_id = ?", record.TrainId).
|
||||||
|
Preload("HeartRates", "heart_rate_type = ?", 1). // 只要有效心率
|
||||||
|
Preload("StrideFreqs", "predict_value = ?", 1). // 只要有效步频
|
||||||
|
First(&fullRecord).Error; err != nil {
|
||||||
|
log.Printf("训练记录%d查询失败,无法计算回归: %v", record.TrainId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查数据是否满足计算条件
|
||||||
|
if len(fullRecord.HeartRates) == 0 || len(fullRecord.StrideFreqs) == 0 {
|
||||||
|
log.Printf("训练记录%d缺少心率或步频数据,跳过回归计算", record.TrainId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算并保存回归结果
|
||||||
|
if _, err := tc.GetOrCalculateRegression(fullRecord.TrainId); err != nil {
|
||||||
|
log.Printf("训练记录%d回归计算失败: %v", fullRecord.TrainId, err)
|
||||||
|
} else {
|
||||||
|
log.Printf("训练记录%d回归结果已保存", fullRecord.TrainId)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusCreated, gin.H{
|
||||||
|
"message": "数据保存成功",
|
||||||
|
"id": record.TrainId,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
|
||||||
|
// 定义分页参数结构
|
||||||
|
type PaginationParams struct {
|
||||||
|
PageNum int `form:"pageNum,default=1"` // 页码,默认第一页
|
||||||
|
PageSize int `form:"pageSize,default=10"` // 每页数量,默认10条
|
||||||
|
}
|
||||||
|
|
||||||
|
var params PaginationParams
|
||||||
|
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证分页参数有效性
|
||||||
|
if params.PageNum < 1 {
|
||||||
|
params.PageNum = 1
|
||||||
|
}
|
||||||
|
if params.PageSize < 1 || params.PageSize > 100 {
|
||||||
|
params.PageSize = 10
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算偏移量
|
||||||
|
offset := (params.PageNum - 1) * params.PageSize
|
||||||
|
|
||||||
|
var (
|
||||||
|
records []models.StepTrainRecord
|
||||||
|
totalRows int64
|
||||||
|
)
|
||||||
|
|
||||||
|
// 获取总记录数
|
||||||
|
if err := tc.DB.Model(&models.StepTrainRecord{}).Count(&totalRows).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询分页数据(按开始时间倒序排列)
|
||||||
|
result := tc.DB.
|
||||||
|
Order("start_time DESC"). // 按开始时间倒序
|
||||||
|
Offset(offset).
|
||||||
|
Limit(params.PageSize).
|
||||||
|
Find(&records)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总页数
|
||||||
|
totalPages := int(math.Ceil(float64(totalRows) / float64(params.PageSize)))
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "查询成功",
|
||||||
|
"data": gin.H{
|
||||||
|
"list": records,
|
||||||
|
"pagination": gin.H{
|
||||||
|
"currentPage": params.PageNum,
|
||||||
|
"pageSize": params.PageSize,
|
||||||
|
"totalPage": totalPages,
|
||||||
|
"totalList": totalRows,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
|
||||||
|
// 从URL路径参数获取trainId
|
||||||
|
trainId := c.Param("trainId")
|
||||||
|
if trainId == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将字符串trainId转换为uint类型
|
||||||
|
tid, err := strconv.ParseInt(trainId, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var record models.StepTrainRecord
|
||||||
|
|
||||||
|
// 查询主记录并预加载关联的心率和步频数据
|
||||||
|
result := tc.DB.Where("train_id = ?", uint(tid)).
|
||||||
|
Preload("HeartRates").
|
||||||
|
Preload("StrideFreqs").
|
||||||
|
First(&record)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 成功返回数据
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "查询成功",
|
||||||
|
"data": record,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 定义结构体
|
||||||
|
type SpeedSegment struct {
|
||||||
|
Duration float64
|
||||||
|
Speed float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现线性回归算法
|
||||||
|
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无数据",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集数据点
|
||||||
|
var points []struct{ x, y float64 }
|
||||||
|
for _, m := range averages {
|
||||||
|
for x, y := range m {
|
||||||
|
points = append(points, struct{ x, y float64 }{x, y})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用回归库计算
|
||||||
|
r := new(regression.Regression)
|
||||||
|
r.SetObserved("y")
|
||||||
|
r.SetVar(0, "x")
|
||||||
|
|
||||||
|
for _, p := range points {
|
||||||
|
r.Train(regression.DataPoint(p.y, []float64{p.x}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Run(); err != nil {
|
||||||
|
log.Printf("线性回归计算失败: %v", err)
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "计算失败",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建结果
|
||||||
|
slope := r.Coeff(1)
|
||||||
|
intercept := r.Coeff(0)
|
||||||
|
r2 := r.R2
|
||||||
|
return models.RegressionResult{
|
||||||
|
RegressionType: models.LinearRegression,
|
||||||
|
Slope: &slope,
|
||||||
|
Intercept: &intercept,
|
||||||
|
RSquared: &r2,
|
||||||
|
Equation: r.Formula,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实现对数和二次回归算法
|
||||||
|
// 对数回归算法
|
||||||
|
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无数据",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集数据点
|
||||||
|
r := new(regression.Regression)
|
||||||
|
r.SetObserved("y")
|
||||||
|
r.SetVar(0, "log(x+1)")
|
||||||
|
|
||||||
|
for _, m := range averages {
|
||||||
|
for speed, hr := range m {
|
||||||
|
logSpeed := math.Log(speed + 1)
|
||||||
|
r.Train(regression.DataPoint(hr, []float64{logSpeed}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Run(); err != nil {
|
||||||
|
log.Printf("对数回归计算失败: %v", err)
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "计算失败",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建结果
|
||||||
|
logA := r.Coeff(1)
|
||||||
|
logB := r.Coeff(0)
|
||||||
|
r2 := r.R2
|
||||||
|
return models.RegressionResult{
|
||||||
|
RegressionType: models.LogarithmicRegression,
|
||||||
|
LogA: &logA,
|
||||||
|
LogB: &logB,
|
||||||
|
RSquared: &r2,
|
||||||
|
Equation: r.Formula,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 二次回归算法
|
||||||
|
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
// if len(averages) == 0 {
|
||||||
|
// return models.RegressionResult{
|
||||||
|
// Equation: "无数据",
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 收集数据点
|
||||||
|
// r := new(regression.Regression)
|
||||||
|
// r.SetObserved("y")
|
||||||
|
// r.SetVar(0, "x")
|
||||||
|
// r.SetVar(1, "x²")
|
||||||
|
//
|
||||||
|
// for _, m := range averages {
|
||||||
|
// for speed, hr := range m {
|
||||||
|
// speedSq := math.Pow(speed, 2)
|
||||||
|
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if err := r.Run(); err != nil {
|
||||||
|
// log.Printf("二次回归计算失败: %v", err)
|
||||||
|
// return models.RegressionResult{
|
||||||
|
// Equation: "计算失败",
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 创建结果
|
||||||
|
// a := r.Coeff(2)
|
||||||
|
// b := r.Coeff(1)
|
||||||
|
// c := r.Coeff(0)
|
||||||
|
// r2 := r.R2
|
||||||
|
// return models.RegressionResult{
|
||||||
|
// RegressionType: models.QuadraticRegression,
|
||||||
|
// QuadraticA: &a,
|
||||||
|
// QuadraticB: &b,
|
||||||
|
// QuadraticC: &c,
|
||||||
|
// RSquared: &r2,
|
||||||
|
// Equation: r.Formula,
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无数据",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤1:收集所有数据点(与Flutter一致)
|
||||||
|
var xValues []float64
|
||||||
|
var yValues []float64
|
||||||
|
for _, m := range averages {
|
||||||
|
for speed, hr := range m {
|
||||||
|
xValues = append(xValues, speed)
|
||||||
|
yValues = append(yValues, hr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n := float64(len(xValues))
|
||||||
|
|
||||||
|
// 步骤2:计算各项和(完全匹配Flutter的计算)
|
||||||
|
var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
|
||||||
|
for i := 0; i < len(xValues); i++ {
|
||||||
|
x := xValues[i]
|
||||||
|
y := yValues[i]
|
||||||
|
x2 := x * x
|
||||||
|
x3 := x2 * x
|
||||||
|
x4 := x3 * x
|
||||||
|
|
||||||
|
sumX += x
|
||||||
|
sumY += y
|
||||||
|
sumX2 += x2
|
||||||
|
sumX3 += x3
|
||||||
|
sumX4 += x4
|
||||||
|
sumXY += x * y
|
||||||
|
sumX2Y += x2 * y
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤3:构建正规方程矩阵(与Flutter完全一致)
|
||||||
|
matrix := [3][3]float64{
|
||||||
|
{n, sumX, sumX2},
|
||||||
|
{sumX, sumX2, sumX3},
|
||||||
|
{sumX2, sumX3, sumX4},
|
||||||
|
}
|
||||||
|
vector := []float64{sumY, sumXY, sumX2Y}
|
||||||
|
|
||||||
|
// 步骤4:计算矩阵行列式(复制Flutter的determinant3x3逻辑)
|
||||||
|
det := matrix[0][0]*(matrix[1][1]*matrix[2][2]-matrix[1][2]*matrix[2][1]) -
|
||||||
|
matrix[0][1]*(matrix[1][0]*matrix[2][2]-matrix[1][2]*matrix[2][0]) +
|
||||||
|
matrix[0][2]*(matrix[1][0]*matrix[2][1]-matrix[1][1]*matrix[2][0])
|
||||||
|
|
||||||
|
if det == 0 {
|
||||||
|
return models.RegressionResult{
|
||||||
|
Equation: "无法拟合",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤5:克莱姆法则求解系数(顺序与Flutter一致)
|
||||||
|
// 注意:最终系数顺序 a=二次项, b=一次项, c=常数项
|
||||||
|
c := det3x3([3][3]float64{
|
||||||
|
{vector[0], matrix[0][1], matrix[0][2]},
|
||||||
|
{vector[1], matrix[1][1], matrix[1][2]},
|
||||||
|
{vector[2], matrix[2][1], matrix[2][2]},
|
||||||
|
}) / det
|
||||||
|
|
||||||
|
b := det3x3([3][3]float64{
|
||||||
|
{matrix[0][0], vector[0], matrix[0][2]},
|
||||||
|
{matrix[1][0], vector[1], matrix[1][2]},
|
||||||
|
{matrix[2][0], vector[2], matrix[2][2]},
|
||||||
|
}) / det
|
||||||
|
|
||||||
|
a := det3x3([3][3]float64{
|
||||||
|
{matrix[0][0], matrix[0][1], vector[0]},
|
||||||
|
{matrix[1][0], matrix[1][1], vector[1]},
|
||||||
|
{matrix[2][0], matrix[2][1], vector[2]},
|
||||||
|
}) / det
|
||||||
|
|
||||||
|
// 步骤6:计算R平方(完全复制Flutter的计算逻辑)
|
||||||
|
var ssRes, ssTot float64
|
||||||
|
meanY := sumY / n
|
||||||
|
for i := 0; i < len(xValues); i++ {
|
||||||
|
x := xValues[i]
|
||||||
|
y := yValues[i]
|
||||||
|
yPred := a*x*x + b*x + c
|
||||||
|
ssRes += math.Pow(y-yPred, 2)
|
||||||
|
ssTot += math.Pow(y-meanY, 2)
|
||||||
|
}
|
||||||
|
rSquared := 0.0
|
||||||
|
if ssTot != 0 {
|
||||||
|
rSquared = 1 - ssRes/ssTot
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步骤7:格式化公式字符串(与Flutter格式完全一致)
|
||||||
|
equation := formatEquation(a, b, c, rSquared)
|
||||||
|
|
||||||
|
return models.RegressionResult{
|
||||||
|
RegressionType: models.QuadraticRegression,
|
||||||
|
QuadraticA: &a,
|
||||||
|
QuadraticB: &b,
|
||||||
|
QuadraticC: &c,
|
||||||
|
RSquared: &rSquared,
|
||||||
|
Equation: equation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3x3行列式计算(与Flutter实现相同)
|
||||||
|
func det3x3(m [3][3]float64) float64 {
|
||||||
|
return m[0][0]*(m[1][1]*m[2][2]-m[1][2]*m[2][1]) -
|
||||||
|
m[0][1]*(m[1][0]*m[2][2]-m[1][2]*m[2][0]) +
|
||||||
|
m[0][2]*(m[1][0]*m[2][1]-m[1][1]*m[2][0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// 公式格式化(完全匹配Flutter格式)
|
||||||
|
func formatEquation(a, b, c, r2 float64) string {
|
||||||
|
// 保留4位小数
|
||||||
|
aStr := fmt.Sprintf("%.4f", a)
|
||||||
|
bStr := fmt.Sprintf("%.4f", b)
|
||||||
|
cStr := fmt.Sprintf("%.4f", c)
|
||||||
|
r2Str := fmt.Sprintf("%.4f", r2)
|
||||||
|
|
||||||
|
builder := strings.Builder{}
|
||||||
|
builder.WriteString("y = ")
|
||||||
|
|
||||||
|
// 处理二次项
|
||||||
|
if a >= 0 {
|
||||||
|
builder.WriteString(aStr + " x²")
|
||||||
|
} else {
|
||||||
|
builder.WriteString("-" + strings.TrimPrefix(aStr, "-") + " x²")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理一次项
|
||||||
|
if b >= 0 {
|
||||||
|
builder.WriteString(" + " + bStr + " x")
|
||||||
|
} else {
|
||||||
|
builder.WriteString(" - " + strings.TrimPrefix(bStr, "-") + " x")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理常数项
|
||||||
|
if c >= 0 {
|
||||||
|
builder.WriteString(" + " + cStr)
|
||||||
|
} else {
|
||||||
|
builder.WriteString(" - " + strings.TrimPrefix(cStr, "-"))
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.WriteString(" (R² = " + r2Str + ")")
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 步频数据转换为速度段
|
||||||
|
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
|
||||||
|
if len(steps) == 0 {
|
||||||
|
return []SpeedSegment{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 过滤零值并排序
|
||||||
|
validSteps := make([]models.StepStrideFreq, 0, len(steps))
|
||||||
|
for _, s := range steps {
|
||||||
|
if s.Value > 0 {
|
||||||
|
validSteps = append(validSteps, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validSteps) == 0 {
|
||||||
|
return []SpeedSegment{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按时间排序
|
||||||
|
for i := 0; i < len(validSteps)-1; i++ {
|
||||||
|
for j := i + 1; j < len(validSteps); j++ {
|
||||||
|
if validSteps[i].Time > validSteps[j].Time {
|
||||||
|
validSteps[i], validSteps[j] = validSteps[j], validSteps[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建速度段
|
||||||
|
segments := make([]SpeedSegment, 0)
|
||||||
|
startTime := validSteps[0].Time
|
||||||
|
currentValue := validSteps[0].Value
|
||||||
|
|
||||||
|
for i := 1; i < len(validSteps); i++ {
|
||||||
|
if validSteps[i].Value != currentValue {
|
||||||
|
duration := float64(validSteps[i].Time-startTime) / 1000.0
|
||||||
|
if duration > 0 {
|
||||||
|
segments = append(segments, SpeedSegment{
|
||||||
|
Duration: duration,
|
||||||
|
Speed: float64(currentValue),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
startTime = validSteps[i].Time
|
||||||
|
currentValue = validSteps[i].Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加最后一个段
|
||||||
|
if len(validSteps) > 0 {
|
||||||
|
duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
|
||||||
|
if duration > 0 {
|
||||||
|
segments = append(segments, SpeedSegment{
|
||||||
|
Duration: duration,
|
||||||
|
Speed: float64(currentValue),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return segments
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算区段平均值
|
||||||
|
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
|
||||||
|
currentTime := 0.0
|
||||||
|
results := make([]map[float64]float64, 0)
|
||||||
|
|
||||||
|
for _, seg := range segments {
|
||||||
|
minRequired := 60 + (60 - float64(errorThreshold))
|
||||||
|
|
||||||
|
// 跳过不满足条件的区段
|
||||||
|
if seg.Duration < minRequired {
|
||||||
|
currentTime += seg.Duration
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算时间窗口
|
||||||
|
startSec := currentTime + 60
|
||||||
|
endSec := currentTime
|
||||||
|
if seg.Duration >= 120 {
|
||||||
|
endSec = currentTime + 120
|
||||||
|
} else {
|
||||||
|
endSec = currentTime + 120 - float64(errorThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集该区段的心率数据
|
||||||
|
sum, count := 0, 0
|
||||||
|
for _, hr := range heartRates {
|
||||||
|
sec := float64(hr.Time) / 1000.0
|
||||||
|
if sec >= startSec && sec <= endSec {
|
||||||
|
sum += hr.Value
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算平均值
|
||||||
|
if count > 0 {
|
||||||
|
avg := float64(sum) / float64(count)
|
||||||
|
results = append(results, map[float64]float64{seg.Speed: avg})
|
||||||
|
}
|
||||||
|
|
||||||
|
currentTime += seg.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算步频区段的心率平均值
|
||||||
|
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
|
||||||
|
segments := convertStrideFrequencyToSegments(steps)
|
||||||
|
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
|
||||||
|
}
|
||||||
|
|
||||||
|
// 存储回归结果到数据库(支持多种回归类型)
|
||||||
|
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
|
||||||
|
return tc.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
for i := range results {
|
||||||
|
results[i].TrainId = trainId
|
||||||
|
// 使用复合唯一约束确保每种回归类型只存储一条记录
|
||||||
|
err := tx.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{
|
||||||
|
{Name: "id"},
|
||||||
|
},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||||
|
"equation": results[i].Equation,
|
||||||
|
"slope": results[i].Slope,
|
||||||
|
"intercept": results[i].Intercept,
|
||||||
|
"log_a": results[i].LogA,
|
||||||
|
"log_b": results[i].LogB,
|
||||||
|
"quadratic_a": results[i].QuadraticA,
|
||||||
|
"quadratic_b": results[i].QuadraticB,
|
||||||
|
"quadratic_c": results[i].QuadraticC,
|
||||||
|
"r_squared": results[i].RSquared,
|
||||||
|
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}),
|
||||||
|
}).Create(&results[i]).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取或计算回归结果(返回多种回归类型列表)
|
||||||
|
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
|
||||||
|
// 尝试从数据库获取所有类型的回归结果
|
||||||
|
var results []models.RegressionResult
|
||||||
|
err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
|
||||||
|
|
||||||
|
// 如果已存在三种类型的结果,直接返回
|
||||||
|
if err == nil && len(results) >= 3 {
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询训练记录及相关数据
|
||||||
|
var record models.StepTrainRecord
|
||||||
|
if err := tc.DB.
|
||||||
|
Where("train_id = ?", uint(trainId)).
|
||||||
|
Preload("HeartRates", "heart_rate_type = ?", 1).
|
||||||
|
Preload("StrideFreqs", "predict_value = ?", 1).
|
||||||
|
First(&record).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算心率平均值
|
||||||
|
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
|
||||||
|
if len(averages) == 0 {
|
||||||
|
return nil, errors.New("无足够数据进行回归计算")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建三种回归类型的结果
|
||||||
|
results = make([]models.RegressionResult, 3)
|
||||||
|
|
||||||
|
// 线性回归
|
||||||
|
linearRes := performLinearRegression(averages)
|
||||||
|
results[0] = models.RegressionResult{
|
||||||
|
RegressionType: models.LinearRegression,
|
||||||
|
TrainId: trainId,
|
||||||
|
Equation: linearRes.Equation,
|
||||||
|
Slope: linearRes.Slope,
|
||||||
|
Intercept: linearRes.Intercept,
|
||||||
|
RSquared: linearRes.RSquared,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对数回归
|
||||||
|
logRes := performLogarithmicRegression(averages)
|
||||||
|
results[1] = models.RegressionResult{
|
||||||
|
RegressionType: models.LogarithmicRegression,
|
||||||
|
TrainId: trainId,
|
||||||
|
Equation: logRes.Equation,
|
||||||
|
LogA: logRes.LogA,
|
||||||
|
LogB: logRes.LogB,
|
||||||
|
RSquared: logRes.RSquared,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 二次回归
|
||||||
|
quadRes := performQuadraticRegression(averages)
|
||||||
|
results[2] = models.RegressionResult{
|
||||||
|
RegressionType: models.QuadraticRegression,
|
||||||
|
TrainId: trainId,
|
||||||
|
Equation: quadRes.Equation,
|
||||||
|
QuadraticA: quadRes.QuadraticA,
|
||||||
|
QuadraticB: quadRes.QuadraticB,
|
||||||
|
QuadraticC: quadRes.QuadraticC,
|
||||||
|
RSquared: quadRes.RSquared,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量保存结果到数据库
|
||||||
|
if err := tc.SaveRegressionResults(trainId, results); err != nil {
|
||||||
|
log.Printf("保存回归结果失败: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 新增接口:获取回归结果
|
||||||
|
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
|
||||||
|
trainIdStr := c.Param("trainId")
|
||||||
|
tid, err := strconv.ParseUint(trainIdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tc.GetOrCalculateRegression(uint(tid))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "获取成功",
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
||||||
|
// 参数解析
|
||||||
|
trainIdStr := c.Param("trainId")
|
||||||
|
regressionTypeStr := c.Query("type")
|
||||||
|
regressionType, err := strconv.Atoi(regressionTypeStr)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证回归类型
|
||||||
|
regType := models.RegressionType(regressionType)
|
||||||
|
if regType != models.LinearRegression && regType != models.QuadraticRegression {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换训练ID
|
||||||
|
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trainId := uint(tid)
|
||||||
|
|
||||||
|
// 确保回归结果存在
|
||||||
|
if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 获取指定训练的基准值
|
||||||
|
var baseValue float64
|
||||||
|
baseQuery := tc.DB.Model(&models.RegressionResult{}).
|
||||||
|
Select(getValueColumn(regType)).
|
||||||
|
Where("train_id = ?", trainId)
|
||||||
|
|
||||||
|
if err := baseQuery.Row().Scan(&baseValue); err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在数据库中进行排名计算
|
||||||
|
var rank struct {
|
||||||
|
BetterCount int64
|
||||||
|
Total int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// 动态生成比较条件
|
||||||
|
betterCondition := fmt.Sprintf("%s %s ?",
|
||||||
|
getValueColumn(regType),
|
||||||
|
getComparisonOperator(regType))
|
||||||
|
|
||||||
|
totalQuery := tc.DB.Model(&models.RegressionResult{}).
|
||||||
|
Where(getTypeCondition(regType))
|
||||||
|
|
||||||
|
if err := totalQuery.Count(&rank.Total).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tc.DB.Model(&models.RegressionResult{}).
|
||||||
|
Where(getTypeCondition(regType)).
|
||||||
|
Where(betterCondition, baseValue).
|
||||||
|
Count(&rank.BetterCount).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算实际排名 (并列排名)
|
||||||
|
currentRank := rank.BetterCount + 1
|
||||||
|
|
||||||
|
// 返回响应
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "排名查询成功",
|
||||||
|
"data": gin.H{
|
||||||
|
"trainId": trainId,
|
||||||
|
"type": regressionType,
|
||||||
|
"rank": currentRank,
|
||||||
|
"total": rank.Total,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取排序字段名
|
||||||
|
func getValueColumn(regType models.RegressionType) string {
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
return "slope"
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
return "ABS(quadratic_a)" // 计算绝对值
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取比较操作符
|
||||||
|
func getComparisonOperator(regType models.RegressionType) string {
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
return "<" // 线性回归:值越小越好
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
return ">" // 二次回归:绝对值越大越好
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取类型条件
|
||||||
|
func getTypeCondition(regType models.RegressionType) string {
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
return "slope IS NOT NULL"
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
return "quadratic_a IS NOT NULL AND quadratic_a < 0" // 确保是负值
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:比较浮点指针(用于线性回归)
|
||||||
|
func compareFloatPtr(a, b *float64, ascending bool) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a == nil {
|
||||||
|
return false // 空值排最后
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return true // 非空值排前
|
||||||
|
}
|
||||||
|
|
||||||
|
if ascending {
|
||||||
|
return *a < *b
|
||||||
|
}
|
||||||
|
return *a > *b
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:比较二次项系数(用于二次回归)
|
||||||
|
func compareQuadraticA(a, b *float64) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 比较绝对值(a和b都是负值,所以取绝对值后大的排前面)
|
||||||
|
return math.Abs(*a) > math.Abs(*b)
|
||||||
|
}
|
||||||
@ -1,14 +1,33 @@
|
|||||||
package controllers
|
package controllers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"gonum.org/v1/gonum/floats"
|
||||||
|
"gonum.org/v1/gonum/stat"
|
||||||
|
"gonum.org/v1/gonum/stat/distuv"
|
||||||
|
)
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"hr_receiver/config"
|
"hr_receiver/config"
|
||||||
"hr_receiver/models"
|
"hr_receiver/models"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"} // 替换为你的具体值
|
||||||
|
|
||||||
|
func contains(s string) bool {
|
||||||
|
for _, item := range analyzeRunTypes {
|
||||||
|
if item == s {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type TrainingController struct {
|
type TrainingController struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
}
|
}
|
||||||
@ -57,13 +76,12 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if contains(record.RunType) {
|
||||||
//// 保存腰带关联关系
|
err := tc.heartRateAnalyze(tx, record)
|
||||||
//if len(record.Belts) > 0 {
|
if err != nil {
|
||||||
// if err := tx.Model(&record).Association("Belts").Replace(record.Belts); err != nil {
|
return err
|
||||||
// return err
|
}
|
||||||
// }
|
}
|
||||||
//}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -73,28 +91,276 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, record)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReceiveTrainingData(c *gin.Context) {
|
|
||||||
var data models.TrainingData
|
|
||||||
|
|
||||||
if err := c.ShouldBindJSON(&data); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
|
||||||
"error": "Invalid request body: " + err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if result := config.DB.Create(&data); result.Error != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
|
||||||
"error": "Failed to save data: " + result.Error.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(http.StatusCreated, gin.H{
|
c.JSON(http.StatusCreated, gin.H{
|
||||||
"message": "Data saved successfully",
|
"message": "数据保存成功",
|
||||||
"id": data.ID,
|
"id": record.TrainId,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// analysis_response.go
|
||||||
|
type AnalysisResponse struct {
|
||||||
|
Status string `json:"status"` // 状态码
|
||||||
|
Message string `json:"message"` // 附加信息
|
||||||
|
Data struct {
|
||||||
|
Mean float64 `json:"mean"` // 均值
|
||||||
|
StdDev float64 `json:"stdDev"` // 标准差
|
||||||
|
Histogram []HistoBin `json:"histogram"` // 直方图数据
|
||||||
|
Curve []CurvePoint `json:"curve"` // 正态曲线数据
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HistoBin struct {
|
||||||
|
BinStart float64 `json:"binStart"` // 区间起始值
|
||||||
|
BinEnd float64 `json:"binEnd"` // 区间结束值
|
||||||
|
Count int `json:"count"` // 该区间计数
|
||||||
|
}
|
||||||
|
|
||||||
|
type CurvePoint struct {
|
||||||
|
X float64 `json:"x"` // X坐标
|
||||||
|
Y float64 `json:"y"` // Y坐标
|
||||||
|
}
|
||||||
|
|
||||||
|
// analysis_handler.go
|
||||||
|
func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) {
|
||||||
|
// 获取数据库连接(根据实际项目配置调整)
|
||||||
|
|
||||||
|
// 1. 获取历史数据
|
||||||
|
aValues, err := collectCurveParams(tc.DB)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": "error",
|
||||||
|
"message": "数据查询失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 检查数据有效性
|
||||||
|
if len(aValues) < 10 { // 至少需要10个样本
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"status": "fail",
|
||||||
|
"message": "数据量不足,至少需要10个样本",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 计算统计量
|
||||||
|
mean, stddev := calculateStats(aValues)
|
||||||
|
|
||||||
|
// 4. 生成直方图数据
|
||||||
|
histogram := calculateHistogram(aValues, 20) // 20个分箱
|
||||||
|
|
||||||
|
// 5. 生成正态曲线
|
||||||
|
x, y := generateNormalCurve(mean, stddev, 100)
|
||||||
|
|
||||||
|
// 6. 构造响应
|
||||||
|
response := AnalysisResponse{
|
||||||
|
Status: "success",
|
||||||
|
Message: "分析完成",
|
||||||
|
Data: struct {
|
||||||
|
Mean float64 `json:"mean"`
|
||||||
|
StdDev float64 `json:"stdDev"`
|
||||||
|
Histogram []HistoBin `json:"histogram"`
|
||||||
|
Curve []CurvePoint `json:"curve"`
|
||||||
|
}{
|
||||||
|
Mean: mean,
|
||||||
|
StdDev: stddev,
|
||||||
|
Histogram: histogram,
|
||||||
|
Curve: convertToCurvePoints(x, y),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直方图计算函数
|
||||||
|
func calculateHistogram(data []float64, bins int) []HistoBin {
|
||||||
|
minV, maxV := floats.Min(data), floats.Max(data)
|
||||||
|
binWidth := (maxV - minV) / float64(bins)
|
||||||
|
|
||||||
|
counts := make([]int, bins)
|
||||||
|
for _, v := range data {
|
||||||
|
idx := int((v - minV) / binWidth)
|
||||||
|
if idx == bins { // 处理最大值刚好等于maxV的情况
|
||||||
|
idx--
|
||||||
|
}
|
||||||
|
counts[idx]++
|
||||||
|
}
|
||||||
|
|
||||||
|
histogram := make([]HistoBin, bins)
|
||||||
|
for i := 0; i < bins; i++ {
|
||||||
|
start := minV + float64(i)*binWidth
|
||||||
|
end := minV + float64(i+1)*binWidth
|
||||||
|
histogram[i] = HistoBin{
|
||||||
|
BinStart: start,
|
||||||
|
BinEnd: end,
|
||||||
|
Count: counts[i],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return histogram
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换曲线数据格式
|
||||||
|
func convertToCurvePoints(x, y []float64) []CurvePoint {
|
||||||
|
points := make([]CurvePoint, len(x))
|
||||||
|
for i := range x {
|
||||||
|
points[i] = CurvePoint{
|
||||||
|
X: x[i],
|
||||||
|
Y: y[i],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
|
||||||
|
var startTime int64
|
||||||
|
if record.TestTime > 0 {
|
||||||
|
startTime = record.TestTime
|
||||||
|
} else {
|
||||||
|
startTime = record.StartTime
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取所有唯一的beltID
|
||||||
|
var beltIDs []uint
|
||||||
|
tx.Model(&models.HeartRate{}).Where("train_id = ?", record.TrainId).
|
||||||
|
Select("DISTINCT belt_id").Pluck("belt_id", &beltIDs)
|
||||||
|
|
||||||
|
// 对每个belt计算
|
||||||
|
for _, bid := range beltIDs {
|
||||||
|
// 计算平均心率
|
||||||
|
ranges := getTimeRanges(startTime)
|
||||||
|
averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 曲线拟合
|
||||||
|
x := []float64{2, 4, 6}
|
||||||
|
y := []float64{averages["2min"], averages["4min"], averages["6min"]}
|
||||||
|
a, b, _ := quadraticFit(x, y)
|
||||||
|
|
||||||
|
// 存储结果
|
||||||
|
analysis := models.BeltAnalysis{
|
||||||
|
TrainID: record.TrainId,
|
||||||
|
RunType: record.RunType,
|
||||||
|
BeltID: bid,
|
||||||
|
Avg2min: averages["2min"],
|
||||||
|
Avg4min: averages["4min"],
|
||||||
|
Avg6min: averages["6min"],
|
||||||
|
CurveParamA: a,
|
||||||
|
CurveParamB: b,
|
||||||
|
}
|
||||||
|
if err := tx.Create(&analysis).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectCurveParams(tx *gorm.DB) ([]float64, error) {
|
||||||
|
var aValues []float64
|
||||||
|
// 查询所有记录的 CurveParamA 字段
|
||||||
|
err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &aValues).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return aValues, nil
|
||||||
|
}
|
||||||
|
func calculateStats(data []float64) (mean, stddev float64) {
|
||||||
|
mean = stat.Mean(data, nil)
|
||||||
|
variance := stat.Variance(data, nil)
|
||||||
|
stddev = math.Sqrt(variance)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) {
|
||||||
|
normal := distuv.Normal{
|
||||||
|
Mu: mean,
|
||||||
|
Sigma: stddev,
|
||||||
|
}
|
||||||
|
|
||||||
|
minV := mean - 3*stddev // 从均值-3σ开始
|
||||||
|
maxV := mean + 3*stddev // 到均值+3σ结束
|
||||||
|
step := (maxV - minV) / float64(numPoints-1)
|
||||||
|
|
||||||
|
for i := 0; i < numPoints; i++ {
|
||||||
|
xi := minV + float64(i)*step
|
||||||
|
yi := normal.Prob(xi)
|
||||||
|
x = append(x, xi)
|
||||||
|
y = append(y, yi)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
|
||||||
|
averages := make(map[string]float64)
|
||||||
|
for key, tr := range ranges {
|
||||||
|
var avg float64
|
||||||
|
// 使用GORM Raw SQL提高效率[6,10](@ref)
|
||||||
|
err := tx.Raw(`
|
||||||
|
SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复
|
||||||
|
FROM heart_rates
|
||||||
|
WHERE train_id = ?
|
||||||
|
AND belt_id = ?
|
||||||
|
AND time BETWEEN ? AND ?`,
|
||||||
|
trainID, beltID, tr.Start, tr.End,
|
||||||
|
).Scan(&avg).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
averages[key] = avg
|
||||||
|
}
|
||||||
|
return averages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//func quadraticFit(x []float64, y []float64) (float64, error) {
|
||||||
|
// // 使用三点计算y=ax²+b的a值(x=[2,4,6]对应分钟)
|
||||||
|
// if len(x) != 3 || len(y) != 3 {
|
||||||
|
// return 0, errors.New("需要三个点")
|
||||||
|
// }
|
||||||
|
// // 构造方程组矩阵(简化计算)
|
||||||
|
// a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
|
||||||
|
// return a, nil
|
||||||
|
//}
|
||||||
|
|
||||||
|
func quadraticFit(x []float64, y []float64) (float64, float64, error) {
|
||||||
|
// 校验输入长度
|
||||||
|
if len(x) != 3 || len(y) != 3 {
|
||||||
|
return 0, 0, errors.New("需要三个点")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算各项累加值
|
||||||
|
var sumX4, sumX2, sumY, sumX2Y float64
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
xi := x[i]
|
||||||
|
xi2 := xi * xi
|
||||||
|
sumX4 += xi2 * xi2 // x^4累加
|
||||||
|
sumX2 += xi2 // x^2累加
|
||||||
|
sumY += y[i] // y累加
|
||||||
|
sumX2Y += xi2 * y[i] // x²y累加
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算行列式
|
||||||
|
determinant := sumX4*3 - sumX2*sumX2
|
||||||
|
if determinant == 0 {
|
||||||
|
return 0, 0, errors.New("无解,行列式为零")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算系数 a 和 b
|
||||||
|
a := (sumX2Y*3 - sumY*sumX2) / determinant
|
||||||
|
b := (sumX4*sumY - sumX2*sumX2Y) / determinant
|
||||||
|
|
||||||
|
return a, b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TimeRange struct {
|
||||||
|
Start int64 // 毫秒时间戳起点
|
||||||
|
End int64 // 毫秒时间戳终点
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTimeRanges(startTime int64) map[string]TimeRange {
|
||||||
|
// 计算相对于训练开始时间的窗口
|
||||||
|
return map[string]TimeRange{
|
||||||
|
"2min": {Start: startTime + 120000, End: startTime + 240000}, // 第2分钟(120-240秒)
|
||||||
|
"4min": {Start: startTime + 240000, End: startTime + 360000},
|
||||||
|
"6min": {Start: startTime + 360000, End: startTime + 480000},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -4,7 +4,7 @@ services:
|
|||||||
app:
|
app:
|
||||||
build: .
|
build: .
|
||||||
ports:
|
ports:
|
||||||
- "8080:8080"
|
- "8180:8080"
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@ -27,7 +27,7 @@ services:
|
|||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "127.0.0.1:5432:5432"
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
pgdata:
|
pgdata:
|
||||||
|
|||||||
7
go.mod
7
go.mod
@ -4,9 +4,10 @@ go 1.23.3
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||||
|
github.com/sajari/regression v1.0.1
|
||||||
github.com/spf13/viper v1.20.0
|
github.com/spf13/viper v1.20.0
|
||||||
|
gonum.org/v1/gonum v0.16.0
|
||||||
gorm.io/driver/postgres v1.5.11
|
gorm.io/driver/postgres v1.5.11
|
||||||
gorm.io/gorm v1.25.12
|
gorm.io/gorm v1.25.12
|
||||||
)
|
)
|
||||||
@ -51,9 +52,9 @@ require (
|
|||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.32.0 // indirect
|
golang.org/x/crypto v0.32.0 // indirect
|
||||||
golang.org/x/net v0.33.0 // indirect
|
golang.org/x/net v0.33.0 // indirect
|
||||||
golang.org/x/sync v0.10.0 // indirect
|
golang.org/x/sync v0.12.0 // indirect
|
||||||
golang.org/x/sys v0.29.0 // indirect
|
golang.org/x/sys v0.29.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.23.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.1 // indirect
|
google.golang.org/protobuf v1.36.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
14
go.sum
14
go.sum
@ -31,8 +31,6 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx
|
|||||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
|
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
@ -77,6 +75,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
|||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||||
|
github.com/sajari/regression v1.0.1 h1:iTVc6ZACGCkoXC+8NdqH5tIreslDTT/bXxT6OmHR5PE=
|
||||||
|
github.com/sajari/regression v1.0.1/go.mod h1:NeG/XTW1lYfGY7YV/Z0nYDV/RGh3wxwd1yW46835flM=
|
||||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||||
@ -114,14 +114,16 @@ golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
|||||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||||
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
|
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
|
||||||
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
|||||||
11
main.go
11
main.go
@ -15,7 +15,16 @@ func main() {
|
|||||||
|
|
||||||
config.DB.Debug()
|
config.DB.Debug()
|
||||||
// 自动迁移模型
|
// 自动迁移模型
|
||||||
config.DB.AutoMigrate(&models.TrainRecord{}, &models.TrainingData{}, &models.Belt{}, &models.HeartRate{})
|
config.DB.AutoMigrate(&models.TrainRecord{},
|
||||||
|
&models.TrainingData{},
|
||||||
|
&models.Belt{},
|
||||||
|
&models.HeartRate{},
|
||||||
|
&models.BeltAnalysis{},
|
||||||
|
&models.StepTrainRecord{},
|
||||||
|
&models.StepHeartRate{},
|
||||||
|
&models.StepStrideFreq{},
|
||||||
|
&models.RegressionResult{},
|
||||||
|
)
|
||||||
|
|
||||||
// 启动服务
|
// 启动服务
|
||||||
r := routes.SetupRouter()
|
r := routes.SetupRouter()
|
||||||
|
|||||||
@ -4,10 +4,12 @@ import (
|
|||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GzipMiddleware() gin.HandlerFunc {
|
func GzipMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// 1. 处理请求解压
|
||||||
if c.Request.Header.Get("Content-Encoding") == "gzip" {
|
if c.Request.Header.Get("Content-Encoding") == "gzip" {
|
||||||
gzReader, err := gzip.NewReader(c.Request.Body)
|
gzReader, err := gzip.NewReader(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -17,6 +19,39 @@ func GzipMiddleware() gin.HandlerFunc {
|
|||||||
defer gzReader.Close()
|
defer gzReader.Close()
|
||||||
c.Request.Body = gzReader
|
c.Request.Body = gzReader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2. 设置响应压缩支持
|
||||||
|
if strings.Contains(c.Request.Header.Get("Accept-Encoding"), "gzip") {
|
||||||
|
// 创建gzip writer
|
||||||
|
gzWriter := gzip.NewWriter(c.Writer)
|
||||||
|
defer gzWriter.Close()
|
||||||
|
|
||||||
|
// 替换原始writer为压缩writer
|
||||||
|
originalWriter := c.Writer
|
||||||
|
c.Writer = &gzipResponseWriter{
|
||||||
|
ResponseWriter: originalWriter,
|
||||||
|
gzWriter: gzWriter,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置响应头
|
||||||
|
c.Header("Content-Encoding", "gzip")
|
||||||
|
c.Header("Vary", "Accept-Encoding")
|
||||||
|
}
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 自定义ResponseWriter实现gzip压缩
|
||||||
|
type gzipResponseWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
gzWriter *gzip.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *gzipResponseWriter) Write(data []byte) (int, error) {
|
||||||
|
return w.gzWriter.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *gzipResponseWriter) WriteString(s string) (int, error) {
|
||||||
|
return w.gzWriter.Write([]byte(s))
|
||||||
|
}
|
||||||
|
|||||||
1
models/pageination.go
Normal file
1
models/pageination.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package models
|
||||||
69
models/step_train.go
Normal file
69
models/step_train.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
|
type StepStrideFreq struct {
|
||||||
|
gorm.Model
|
||||||
|
TrainId uint `gorm:"column:train_id; index" json:"trainId"` // 外键关联训练记录[4](@ref)
|
||||||
|
Time int64 `gorm:"type:bigint" json:"time"` // 保持与前端一致的毫秒时间戳[3](@ref)
|
||||||
|
Value int `gorm:"type:int" json:"value"`
|
||||||
|
Count int `gorm:"type:int" json:"count"`
|
||||||
|
PredictValue int `gorm:"type:int" json:"predictValue"`
|
||||||
|
Identifier string `gorm:"uniqueIndex;type:varchar(255)" json:"identifier"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对应Flutter的HeartRate结构
|
||||||
|
type StepHeartRate struct {
|
||||||
|
gorm.Model
|
||||||
|
TrainId uint `gorm:"column:train_id; index" json:"trainId"` // 外键关联训练记录[4](@ref)
|
||||||
|
Time int64 `gorm:"type:bigint" json:"time"` // 保持与前端一致的毫秒时间戳[3](@ref)
|
||||||
|
Value int `gorm:"type:int" json:"value"`
|
||||||
|
HeartRateType int `gorm:"type:int" json:"predictValue"`
|
||||||
|
Identifier string `gorm:"uniqueIndex;type:varchar(255)" json:"identifier"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对应Flutter的TrainRecord结构
|
||||||
|
type StepTrainRecord struct {
|
||||||
|
gorm.Model
|
||||||
|
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
|
||||||
|
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
|
||||||
|
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
|
||||||
|
Name string `gorm:"size:100" json:"name"`
|
||||||
|
RunType string `gorm:"size:100" json:"runType"`
|
||||||
|
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
|
||||||
|
Duration int `gorm:"type:int" json:"duration"` // 持续时间(秒)
|
||||||
|
DeadZone int `gorm:"type:int" json:"deadZone"`
|
||||||
|
Evaluation string `gorm:"size:50" json:"evaluation"`
|
||||||
|
HeartRates []StepHeartRate `gorm:"foreignKey:TrainId;references:TrainId" json:"heartRates"`
|
||||||
|
StrideFreqs []StepStrideFreq `gorm:"foreignKey:TrainId;references:TrainId" json:"strideFreqs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegressionType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
LinearRegression RegressionType = iota + 1
|
||||||
|
LogarithmicRegression
|
||||||
|
QuadraticRegression
|
||||||
|
)
|
||||||
|
|
||||||
|
type RegressionResult struct {
|
||||||
|
gorm.Model
|
||||||
|
RegressionType RegressionType `gorm:"column:regression_type;index" json:"regressionType"` // 训练记录ID
|
||||||
|
TrainId uint `gorm:"column:train_id;index" json:"trainId"` // 训练记录ID
|
||||||
|
Equation string `gorm:"type:text" json:"equation"` // 回归方程
|
||||||
|
|
||||||
|
// 线性回归系数
|
||||||
|
Slope *float64 `gorm:"column:slope" json:"slope"`
|
||||||
|
Intercept *float64 `gorm:"column:intercept" json:"intercept"`
|
||||||
|
|
||||||
|
// 对数回归系数
|
||||||
|
LogA *float64 `gorm:"column:log_a" json:"logA"`
|
||||||
|
LogB *float64 `gorm:"column:log_b" json:"logB"`
|
||||||
|
|
||||||
|
// 二次回归系数
|
||||||
|
QuadraticA *float64 `gorm:"column:quadratic_a" json:"quadraticA"`
|
||||||
|
QuadraticB *float64 `gorm:"column:quadratic_b" json:"quadraticB"`
|
||||||
|
QuadraticC *float64 `gorm:"column:quadratic_c" json:"quadraticC"`
|
||||||
|
|
||||||
|
RSquared *float64 `gorm:"column:r_squared" json:"rSquared"` // R平方值
|
||||||
|
}
|
||||||
@ -18,6 +18,21 @@ type Belt struct {
|
|||||||
Name string `gorm:"size:100" json:"name"`
|
Name string `gorm:"size:100" json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 分析结果存储实体
|
||||||
|
type BeltAnalysis struct {
|
||||||
|
gorm.Model
|
||||||
|
TrainID uint `gorm:"index;not null"` // 关联训练记录
|
||||||
|
BeltID uint `gorm:"index;not null"` // 腰带唯一标识
|
||||||
|
RunType string `gorm:"size:100" json:"RunType"`
|
||||||
|
Avg2min float64 `gorm:"type:double precision"` // 第2分钟平均心率
|
||||||
|
Avg4min float64 `gorm:"type:double precision"` // 第4分钟平均心率
|
||||||
|
Avg6min float64 `gorm:"type:double precision"` // 第6分钟平均心率
|
||||||
|
CurveParamA float64 `gorm:"type:double precision"` // 拟合参数a值
|
||||||
|
CurveParamB float64 `gorm:"type:double precision"` // 拟合参数a值
|
||||||
|
}
|
||||||
|
|
||||||
|
// 中间计算结构(无需持久化)
|
||||||
|
|
||||||
// 对应Flutter的HeartRate结构
|
// 对应Flutter的HeartRate结构
|
||||||
type HeartRate struct {
|
type HeartRate struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
@ -35,8 +50,10 @@ type TrainRecord struct {
|
|||||||
gorm.Model
|
gorm.Model
|
||||||
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
|
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
|
||||||
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
|
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
|
||||||
|
TestTime int64 `gorm:"type:bigint" json:"testTime"` // 开始时间戳
|
||||||
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
|
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
|
||||||
Name string `gorm:"size:100" json:"name"`
|
Name string `gorm:"size:100" json:"name"`
|
||||||
|
RunType string `gorm:"size:100" json:"RunType"`
|
||||||
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
|
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
|
||||||
Duration int `gorm:"type:int" json:"duration"` // 持续时间(秒)
|
Duration int `gorm:"type:int" json:"duration"` // 持续时间(秒)
|
||||||
PeopleNum int `gorm:"type:int" json:"peopleNum"`
|
PeopleNum int `gorm:"type:int" json:"peopleNum"`
|
||||||
|
|||||||
@ -12,12 +12,22 @@ func SetupRouter() *gin.Engine {
|
|||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.Use(middleware.GzipMiddleware())
|
r.Use(middleware.GzipMiddleware())
|
||||||
trainingController := controllers.NewTrainingController()
|
trainingController := controllers.NewTrainingController()
|
||||||
|
stepTrainController := controllers.NewStepTrainingController()
|
||||||
|
|
||||||
v1 := r.Group("/api/v1")
|
v1 := r.Group("/api/v1")
|
||||||
{
|
{
|
||||||
records := v1.Group("/train-records").Use(middleware.AuthMiddleware())
|
records := v1.Group("/train-records") //.Use(middleware.AuthMiddleware())
|
||||||
{
|
{
|
||||||
records.POST("", trainingController.CreateTrainingRecord)
|
records.POST("", trainingController.CreateTrainingRecord)
|
||||||
|
records.GET("/analysis", trainingController.HandleCurveAnalysis)
|
||||||
|
// 可扩展其他路由:GET, PUT, DELETE等
|
||||||
|
}
|
||||||
|
steps := v1.Group("/step") //.Use(middleware.AuthMiddleware())
|
||||||
|
{
|
||||||
|
steps.POST("", stepTrainController.CreateTrainingRecord)
|
||||||
|
steps.GET("train-records", stepTrainController.GetTrainingRecords)
|
||||||
|
steps.GET("train-data/:trainId", stepTrainController.GetTrainingRecordByTrainId)
|
||||||
|
steps.GET("train-rank/:trainId", stepTrainController.GetTrainingRank)
|
||||||
// 可扩展其他路由:GET, PUT, DELETE等
|
// 可扩展其他路由:GET, PUT, DELETE等
|
||||||
}
|
}
|
||||||
auth := v1.Group("/auth")
|
auth := v1.Group("/auth")
|
||||||
|
|||||||
Reference in New Issue
Block a user