Compare commits

...

19 Commits

Author SHA1 Message Date
1a252a12be feat: count. 2025-10-22 15:25:42 +08:00
8d8dd26a2c feat: calculate when upload. 2025-08-05 10:55:19 +08:00
a14c553736 refactor: port. 2025-08-05 10:50:39 +08:00
563fdd8a7e fix: rank. 2025-08-05 10:43:19 +08:00
d3e87fda67 fix: linear slope. 2025-08-05 10:18:35 +08:00
3078c13e14 fix: sort. 2025-08-05 09:55:37 +08:00
7a2a44e327 feat: rank query 2025-08-04 11:18:07 +08:00
b00767dfcd feat: regression 2025-08-04 10:46:44 +08:00
2aa22b1385 fix: json 2025-07-03 09:38:51 +08:00
07963218cd fix: type error 2025-06-27 10:26:23 +08:00
0852f4bc23 feat: gzip response. 2025-06-27 09:46:04 +08:00
19148d7d35 feat: query. 2025-06-27 09:31:42 +08:00
4e10359a5b feat: step train. 2025-06-25 09:36:27 +08:00
8510d19baa refactor: train. 2025-06-25 08:53:08 +08:00
f4df1724b2 fix: port map. 2025-04-08 16:17:49 +08:00
5cca15808d feat: run bash. 2025-04-07 11:14:46 +08:00
d69aaa5704 refactor: config.sample.yaml. 2025-04-07 09:48:49 +08:00
6e8232ff7f refactor: data analyze result. 2025-04-01 15:35:24 +08:00
7b9e870bf9 refactor: data analyze. 2025-04-01 10:13:38 +08:00
14 changed files with 1348 additions and 43 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
hr_receiver.iml
main.go.bak
config.yaml

895
controllers/step_train.go Normal file
View File

@ -0,0 +1,895 @@
package controllers
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/sajari/regression"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"hr_receiver/config"
"hr_receiver/models"
"log"
"math"
"net/http"
"strconv"
"strings"
)
type StepTrainingController struct {
DB *gorm.DB
}
func NewStepTrainingController() *StepTrainingController {
return &StepTrainingController{DB: config.DB}
}
// 接收训练记录
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
var record models.StepTrainRecord
// 绑定并验证JSON数据
if err := c.ShouldBindJSON(&record); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 使用事务保存数据[4](@ref)
err := tc.DB.Transaction(func(tx *gorm.DB) error {
// 保存主记录
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "train_id"}}, // 指定冲突的列
DoUpdates: clause.Assignments(map[string]interface{}{
"max_heart_rate": record.MaxHeartRate,
"start_time": record.StartTime,
"end_time": record.EndTime,
"duration": record.Duration,
"dead_zone": record.DeadZone,
"name": record.Name,
"evaluation": record.Evaluation,
}),
}).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
return err
}
// 保存关联的心率数据
for i := range record.HeartRates {
if err := tx.Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
DoUpdates: clause.Assignments(map[string]interface{}{"heart_rate_type": record.HeartRates[i].HeartRateType, "value": record.HeartRates[i].Value, "time": record.HeartRates[i].Time}),
},
).Create(&record.HeartRates[i]).Error; err != nil {
return err
}
}
for i := range record.StrideFreqs {
if err := tx.Clauses(
clause.OnConflict{
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
DoUpdates: clause.Assignments(map[string]interface{}{"value": record.StrideFreqs[i].Value, "time": record.StrideFreqs[i].Time}),
},
).Create(&record.StrideFreqs[i]).Error; err != nil {
return err
}
}
return nil
})
// ====== 新增部分:启动异步回归计算 ======
go func() {
// 查询完整数据(需要关联的心率和步频数据)
var fullRecord models.StepTrainRecord
if err := tc.DB.
Where("train_id = ?", record.TrainId).
Preload("HeartRates", "heart_rate_type = ?", 1). // 只要有效心率
Preload("StrideFreqs", "predict_value = ?", 1). // 只要有效步频
First(&fullRecord).Error; err != nil {
log.Printf("训练记录%d查询失败无法计算回归: %v", record.TrainId, err)
return
}
// 检查数据是否满足计算条件
if len(fullRecord.HeartRates) == 0 || len(fullRecord.StrideFreqs) == 0 {
log.Printf("训练记录%d缺少心率或步频数据跳过回归计算", record.TrainId)
return
}
// 计算并保存回归结果
if _, err := tc.GetOrCalculateRegression(fullRecord.TrainId); err != nil {
log.Printf("训练记录%d回归计算失败: %v", fullRecord.TrainId, err)
} else {
log.Printf("训练记录%d回归结果已保存", fullRecord.TrainId)
}
}()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "数据保存成功",
"id": record.TrainId,
})
}
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
// 定义分页参数结构
type PaginationParams struct {
PageNum int `form:"pageNum,default=1"` // 页码,默认第一页
PageSize int `form:"pageSize,default=10"` // 每页数量默认10条
}
var params PaginationParams
if err := c.ShouldBindQuery(&params); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 验证分页参数有效性
if params.PageNum < 1 {
params.PageNum = 1
}
if params.PageSize < 1 || params.PageSize > 100 {
params.PageSize = 10
}
// 计算偏移量
offset := (params.PageNum - 1) * params.PageSize
var (
records []models.StepTrainRecord
totalRows int64
)
// 获取总记录数
if err := tc.DB.Model(&models.StepTrainRecord{}).Count(&totalRows).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
return
}
// 查询分页数据(按开始时间倒序排列)
result := tc.DB.
Order("start_time DESC"). // 按开始时间倒序
Offset(offset).
Limit(params.PageSize).
Find(&records)
if result.Error != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
return
}
// 计算总页数
totalPages := int(math.Ceil(float64(totalRows) / float64(params.PageSize)))
c.JSON(http.StatusOK, gin.H{
"message": "查询成功",
"data": gin.H{
"list": records,
"pagination": gin.H{
"currentPage": params.PageNum,
"pageSize": params.PageSize,
"totalPage": totalPages,
"totalList": totalRows,
},
},
})
}
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
// 从URL路径参数获取trainId
trainId := c.Param("trainId")
if trainId == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
return
}
// 将字符串trainId转换为uint类型
tid, err := strconv.ParseInt(trainId, 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
return
}
var record models.StepTrainRecord
// 查询主记录并预加载关联的心率和步频数据
result := tc.DB.Where("train_id = ?", uint(tid)).
Preload("HeartRates").
Preload("StrideFreqs").
First(&record)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
}
return
}
// 成功返回数据
c.JSON(http.StatusOK, gin.H{
"message": "查询成功",
"data": record,
})
}
// 定义结构体
type SpeedSegment struct {
Duration float64
Speed float64
}
// 实现线性回归算法
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
if len(averages) == 0 {
return models.RegressionResult{
Equation: "无数据",
}
}
// 收集数据点
var points []struct{ x, y float64 }
for _, m := range averages {
for x, y := range m {
points = append(points, struct{ x, y float64 }{x, y})
}
}
// 使用回归库计算
r := new(regression.Regression)
r.SetObserved("y")
r.SetVar(0, "x")
for _, p := range points {
r.Train(regression.DataPoint(p.y, []float64{p.x}))
}
if err := r.Run(); err != nil {
log.Printf("线性回归计算失败: %v", err)
return models.RegressionResult{
Equation: "计算失败",
}
}
// 创建结果
slope := r.Coeff(1)
intercept := r.Coeff(0)
r2 := r.R2
return models.RegressionResult{
RegressionType: models.LinearRegression,
Slope: &slope,
Intercept: &intercept,
RSquared: &r2,
Equation: r.Formula,
}
}
// 实现对数和二次回归算法
// 对数回归算法
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
if len(averages) == 0 {
return models.RegressionResult{
Equation: "无数据",
}
}
// 收集数据点
r := new(regression.Regression)
r.SetObserved("y")
r.SetVar(0, "log(x+1)")
for _, m := range averages {
for speed, hr := range m {
logSpeed := math.Log(speed + 1)
r.Train(regression.DataPoint(hr, []float64{logSpeed}))
}
}
if err := r.Run(); err != nil {
log.Printf("对数回归计算失败: %v", err)
return models.RegressionResult{
Equation: "计算失败",
}
}
// 创建结果
logA := r.Coeff(1)
logB := r.Coeff(0)
r2 := r.R2
return models.RegressionResult{
RegressionType: models.LogarithmicRegression,
LogA: &logA,
LogB: &logB,
RSquared: &r2,
Equation: r.Formula,
}
}
// 二次回归算法
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
// if len(averages) == 0 {
// return models.RegressionResult{
// Equation: "无数据",
// }
// }
//
// // 收集数据点
// r := new(regression.Regression)
// r.SetObserved("y")
// r.SetVar(0, "x")
// r.SetVar(1, "x²")
//
// for _, m := range averages {
// for speed, hr := range m {
// speedSq := math.Pow(speed, 2)
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
// }
// }
//
// if err := r.Run(); err != nil {
// log.Printf("二次回归计算失败: %v", err)
// return models.RegressionResult{
// Equation: "计算失败",
// }
// }
//
// // 创建结果
// a := r.Coeff(2)
// b := r.Coeff(1)
// c := r.Coeff(0)
// r2 := r.R2
// return models.RegressionResult{
// RegressionType: models.QuadraticRegression,
// QuadraticA: &a,
// QuadraticB: &b,
// QuadraticC: &c,
// RSquared: &r2,
// Equation: r.Formula,
// }
//}
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
if len(averages) == 0 {
return models.RegressionResult{
Equation: "无数据",
}
}
// 步骤1收集所有数据点与Flutter一致
var xValues []float64
var yValues []float64
for _, m := range averages {
for speed, hr := range m {
xValues = append(xValues, speed)
yValues = append(yValues, hr)
}
}
n := float64(len(xValues))
// 步骤2计算各项和完全匹配Flutter的计算
var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
for i := 0; i < len(xValues); i++ {
x := xValues[i]
y := yValues[i]
x2 := x * x
x3 := x2 * x
x4 := x3 * x
sumX += x
sumY += y
sumX2 += x2
sumX3 += x3
sumX4 += x4
sumXY += x * y
sumX2Y += x2 * y
}
// 步骤3构建正规方程矩阵与Flutter完全一致
matrix := [3][3]float64{
{n, sumX, sumX2},
{sumX, sumX2, sumX3},
{sumX2, sumX3, sumX4},
}
vector := []float64{sumY, sumXY, sumX2Y}
// 步骤4计算矩阵行列式复制Flutter的determinant3x3逻辑
det := matrix[0][0]*(matrix[1][1]*matrix[2][2]-matrix[1][2]*matrix[2][1]) -
matrix[0][1]*(matrix[1][0]*matrix[2][2]-matrix[1][2]*matrix[2][0]) +
matrix[0][2]*(matrix[1][0]*matrix[2][1]-matrix[1][1]*matrix[2][0])
if det == 0 {
return models.RegressionResult{
Equation: "无法拟合",
}
}
// 步骤5克莱姆法则求解系数顺序与Flutter一致
// 注意:最终系数顺序 a=二次项, b=一次项, c=常数项
c := det3x3([3][3]float64{
{vector[0], matrix[0][1], matrix[0][2]},
{vector[1], matrix[1][1], matrix[1][2]},
{vector[2], matrix[2][1], matrix[2][2]},
}) / det
b := det3x3([3][3]float64{
{matrix[0][0], vector[0], matrix[0][2]},
{matrix[1][0], vector[1], matrix[1][2]},
{matrix[2][0], vector[2], matrix[2][2]},
}) / det
a := det3x3([3][3]float64{
{matrix[0][0], matrix[0][1], vector[0]},
{matrix[1][0], matrix[1][1], vector[1]},
{matrix[2][0], matrix[2][1], vector[2]},
}) / det
// 步骤6计算R平方完全复制Flutter的计算逻辑
var ssRes, ssTot float64
meanY := sumY / n
for i := 0; i < len(xValues); i++ {
x := xValues[i]
y := yValues[i]
yPred := a*x*x + b*x + c
ssRes += math.Pow(y-yPred, 2)
ssTot += math.Pow(y-meanY, 2)
}
rSquared := 0.0
if ssTot != 0 {
rSquared = 1 - ssRes/ssTot
}
// 步骤7格式化公式字符串与Flutter格式完全一致
equation := formatEquation(a, b, c, rSquared)
return models.RegressionResult{
RegressionType: models.QuadraticRegression,
QuadraticA: &a,
QuadraticB: &b,
QuadraticC: &c,
RSquared: &rSquared,
Equation: equation,
}
}
// 3x3行列式计算与Flutter实现相同
func det3x3(m [3][3]float64) float64 {
return m[0][0]*(m[1][1]*m[2][2]-m[1][2]*m[2][1]) -
m[0][1]*(m[1][0]*m[2][2]-m[1][2]*m[2][0]) +
m[0][2]*(m[1][0]*m[2][1]-m[1][1]*m[2][0])
}
// 公式格式化完全匹配Flutter格式
func formatEquation(a, b, c, r2 float64) string {
// 保留4位小数
aStr := fmt.Sprintf("%.4f", a)
bStr := fmt.Sprintf("%.4f", b)
cStr := fmt.Sprintf("%.4f", c)
r2Str := fmt.Sprintf("%.4f", r2)
builder := strings.Builder{}
builder.WriteString("y = ")
// 处理二次项
if a >= 0 {
builder.WriteString(aStr + " x²")
} else {
builder.WriteString("-" + strings.TrimPrefix(aStr, "-") + " x²")
}
// 处理一次项
if b >= 0 {
builder.WriteString(" + " + bStr + " x")
} else {
builder.WriteString(" - " + strings.TrimPrefix(bStr, "-") + " x")
}
// 处理常数项
if c >= 0 {
builder.WriteString(" + " + cStr)
} else {
builder.WriteString(" - " + strings.TrimPrefix(cStr, "-"))
}
builder.WriteString(" (R² = " + r2Str + ")")
return builder.String()
}
// 步频数据转换为速度段
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
if len(steps) == 0 {
return []SpeedSegment{}
}
// 过滤零值并排序
validSteps := make([]models.StepStrideFreq, 0, len(steps))
for _, s := range steps {
if s.Value > 0 {
validSteps = append(validSteps, s)
}
}
if len(validSteps) == 0 {
return []SpeedSegment{}
}
// 按时间排序
for i := 0; i < len(validSteps)-1; i++ {
for j := i + 1; j < len(validSteps); j++ {
if validSteps[i].Time > validSteps[j].Time {
validSteps[i], validSteps[j] = validSteps[j], validSteps[i]
}
}
}
// 创建速度段
segments := make([]SpeedSegment, 0)
startTime := validSteps[0].Time
currentValue := validSteps[0].Value
for i := 1; i < len(validSteps); i++ {
if validSteps[i].Value != currentValue {
duration := float64(validSteps[i].Time-startTime) / 1000.0
if duration > 0 {
segments = append(segments, SpeedSegment{
Duration: duration,
Speed: float64(currentValue),
})
}
startTime = validSteps[i].Time
currentValue = validSteps[i].Value
}
}
// 添加最后一个段
if len(validSteps) > 0 {
duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
if duration > 0 {
segments = append(segments, SpeedSegment{
Duration: duration,
Speed: float64(currentValue),
})
}
}
return segments
}
// 计算区段平均值
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
currentTime := 0.0
results := make([]map[float64]float64, 0)
for _, seg := range segments {
minRequired := 60 + (60 - float64(errorThreshold))
// 跳过不满足条件的区段
if seg.Duration < minRequired {
currentTime += seg.Duration
continue
}
// 计算时间窗口
startSec := currentTime + 60
endSec := currentTime
if seg.Duration >= 120 {
endSec = currentTime + 120
} else {
endSec = currentTime + 120 - float64(errorThreshold)
}
// 收集该区段的心率数据
sum, count := 0, 0
for _, hr := range heartRates {
sec := float64(hr.Time) / 1000.0
if sec >= startSec && sec <= endSec {
sum += hr.Value
count++
}
}
// 计算平均值
if count > 0 {
avg := float64(sum) / float64(count)
results = append(results, map[float64]float64{seg.Speed: avg})
}
currentTime += seg.Duration
}
return results
}
// 计算步频区段的心率平均值
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
segments := convertStrideFrequencyToSegments(steps)
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
}
// 存储回归结果到数据库(支持多种回归类型)
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
return tc.DB.Transaction(func(tx *gorm.DB) error {
for i := range results {
results[i].TrainId = trainId
// 使用复合唯一约束确保每种回归类型只存储一条记录
err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "id"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"equation": results[i].Equation,
"slope": results[i].Slope,
"intercept": results[i].Intercept,
"log_a": results[i].LogA,
"log_b": results[i].LogB,
"quadratic_a": results[i].QuadraticA,
"quadratic_b": results[i].QuadraticB,
"quadratic_c": results[i].QuadraticC,
"r_squared": results[i].RSquared,
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
}),
}).Create(&results[i]).Error
if err != nil {
return err
}
}
return nil
})
}
// 获取或计算回归结果(返回多种回归类型列表)
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
// 尝试从数据库获取所有类型的回归结果
var results []models.RegressionResult
err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
// 如果已存在三种类型的结果,直接返回
if err == nil && len(results) >= 3 {
return results, nil
}
// 查询训练记录及相关数据
var record models.StepTrainRecord
if err := tc.DB.
Where("train_id = ?", uint(trainId)).
Preload("HeartRates", "heart_rate_type = ?", 1).
Preload("StrideFreqs", "predict_value = ?", 1).
First(&record).Error; err != nil {
return nil, err
}
// 计算心率平均值
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
if len(averages) == 0 {
return nil, errors.New("无足够数据进行回归计算")
}
// 创建三种回归类型的结果
results = make([]models.RegressionResult, 3)
// 线性回归
linearRes := performLinearRegression(averages)
results[0] = models.RegressionResult{
RegressionType: models.LinearRegression,
TrainId: trainId,
Equation: linearRes.Equation,
Slope: linearRes.Slope,
Intercept: linearRes.Intercept,
RSquared: linearRes.RSquared,
}
// 对数回归
logRes := performLogarithmicRegression(averages)
results[1] = models.RegressionResult{
RegressionType: models.LogarithmicRegression,
TrainId: trainId,
Equation: logRes.Equation,
LogA: logRes.LogA,
LogB: logRes.LogB,
RSquared: logRes.RSquared,
}
// 二次回归
quadRes := performQuadraticRegression(averages)
results[2] = models.RegressionResult{
RegressionType: models.QuadraticRegression,
TrainId: trainId,
Equation: quadRes.Equation,
QuadraticA: quadRes.QuadraticA,
QuadraticB: quadRes.QuadraticB,
QuadraticC: quadRes.QuadraticC,
RSquared: quadRes.RSquared,
}
// 批量保存结果到数据库
if err := tc.SaveRegressionResults(trainId, results); err != nil {
log.Printf("保存回归结果失败: %v", err)
return nil, err
}
return results, nil
}
// 新增接口:获取回归结果
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
trainIdStr := c.Param("trainId")
tid, err := strconv.ParseUint(trainIdStr, 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
return
}
result, err := tc.GetOrCalculateRegression(uint(tid))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "获取成功",
"data": result,
})
}
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
// 参数解析
trainIdStr := c.Param("trainId")
regressionTypeStr := c.Query("type")
regressionType, err := strconv.Atoi(regressionTypeStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
return
}
// 验证回归类型
regType := models.RegressionType(regressionType)
if regType != models.LinearRegression && regType != models.QuadraticRegression {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
return
}
// 转换训练ID
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
return
}
trainId := uint(tid)
// 确保回归结果存在
if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
return
}
// 获取指定训练的基准值
var baseValue float64
baseQuery := tc.DB.Model(&models.RegressionResult{}).
Select(getValueColumn(regType)).
Where("train_id = ?", trainId)
if err := baseQuery.Row().Scan(&baseValue); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
return
}
// 在数据库中进行排名计算
var rank struct {
BetterCount int64
Total int64
}
// 动态生成比较条件
betterCondition := fmt.Sprintf("%s %s ?",
getValueColumn(regType),
getComparisonOperator(regType))
totalQuery := tc.DB.Model(&models.RegressionResult{}).
Where(getTypeCondition(regType))
if err := totalQuery.Count(&rank.Total).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
return
}
if err := tc.DB.Model(&models.RegressionResult{}).
Where(getTypeCondition(regType)).
Where(betterCondition, baseValue).
Count(&rank.BetterCount).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
return
}
// 计算实际排名 (并列排名)
currentRank := rank.BetterCount + 1
// 返回响应
c.JSON(http.StatusOK, gin.H{
"message": "排名查询成功",
"data": gin.H{
"trainId": trainId,
"type": regressionType,
"rank": currentRank,
"total": rank.Total,
},
})
}
// 辅助函数:获取排序字段名
func getValueColumn(regType models.RegressionType) string {
switch regType {
case models.LinearRegression:
return "slope"
case models.QuadraticRegression:
return "ABS(quadratic_a)" // 计算绝对值
default:
return ""
}
}
// 辅助函数:获取比较操作符
func getComparisonOperator(regType models.RegressionType) string {
switch regType {
case models.LinearRegression:
return "<" // 线性回归:值越小越好
case models.QuadraticRegression:
return ">" // 二次回归:绝对值越大越好
default:
return ""
}
}
// 辅助函数:获取类型条件
func getTypeCondition(regType models.RegressionType) string {
switch regType {
case models.LinearRegression:
return "slope IS NOT NULL"
case models.QuadraticRegression:
return "quadratic_a IS NOT NULL AND quadratic_a < 0" // 确保是负值
default:
return ""
}
}
// 辅助函数:比较浮点指针(用于线性回归)
func compareFloatPtr(a, b *float64, ascending bool) bool {
if a == nil && b == nil {
return false
}
if a == nil {
return false // 空值排最后
}
if b == nil {
return true // 非空值排前
}
if ascending {
return *a < *b
}
return *a > *b
}
// 辅助函数:比较二次项系数(用于二次回归)
func compareQuadraticA(a, b *float64) bool {
if a == nil && b == nil {
return false
}
if a == nil {
return false
}
if b == nil {
return true
}
// 比较绝对值a和b都是负值所以取绝对值后大的排前面
return math.Abs(*a) > math.Abs(*b)
}

View File

@ -1,14 +1,33 @@
package controllers
import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
import (
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"hr_receiver/config"
"hr_receiver/models"
"math"
"net/http"
)
var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"} // 替换为你的具体值
func contains(s string) bool {
for _, item := range analyzeRunTypes {
if item == s {
return true
}
}
return false
}
type TrainingController struct {
DB *gorm.DB
}
@ -57,13 +76,12 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
return err
}
}
//// 保存腰带关联关系
//if len(record.Belts) > 0 {
// if err := tx.Model(&record).Association("Belts").Replace(record.Belts); err != nil {
// return err
// }
//}
if contains(record.RunType) {
err := tc.heartRateAnalyze(tx, record)
if err != nil {
return err
}
}
return nil
})
@ -79,25 +97,270 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
})
}
func ReceiveTrainingData(c *gin.Context) {
var data models.TrainingData
if err := c.ShouldBindJSON(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid request body: " + err.Error(),
})
return
}
if result := config.DB.Create(&data); result.Error != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to save data: " + result.Error.Error(),
})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "Data saved successfully",
"id": data.ID,
})
// analysis_response.go
type AnalysisResponse struct {
Status string `json:"status"` // 状态码
Message string `json:"message"` // 附加信息
Data struct {
Mean float64 `json:"mean"` // 均值
StdDev float64 `json:"stdDev"` // 标准差
Histogram []HistoBin `json:"histogram"` // 直方图数据
Curve []CurvePoint `json:"curve"` // 正态曲线数据
} `json:"data"`
}
type HistoBin struct {
BinStart float64 `json:"binStart"` // 区间起始值
BinEnd float64 `json:"binEnd"` // 区间结束值
Count int `json:"count"` // 该区间计数
}
type CurvePoint struct {
X float64 `json:"x"` // X坐标
Y float64 `json:"y"` // Y坐标
}
// analysis_handler.go
func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) {
// 获取数据库连接(根据实际项目配置调整)
// 1. 获取历史数据
aValues, err := collectCurveParams(tc.DB)
if err != nil {
c.JSON(500, gin.H{
"status": "error",
"message": "数据查询失败: " + err.Error(),
})
return
}
// 2. 检查数据有效性
if len(aValues) < 10 { // 至少需要10个样本
c.JSON(400, gin.H{
"status": "fail",
"message": "数据量不足至少需要10个样本",
})
return
}
// 3. 计算统计量
mean, stddev := calculateStats(aValues)
// 4. 生成直方图数据
histogram := calculateHistogram(aValues, 20) // 20个分箱
// 5. 生成正态曲线
x, y := generateNormalCurve(mean, stddev, 100)
// 6. 构造响应
response := AnalysisResponse{
Status: "success",
Message: "分析完成",
Data: struct {
Mean float64 `json:"mean"`
StdDev float64 `json:"stdDev"`
Histogram []HistoBin `json:"histogram"`
Curve []CurvePoint `json:"curve"`
}{
Mean: mean,
StdDev: stddev,
Histogram: histogram,
Curve: convertToCurvePoints(x, y),
},
}
c.JSON(200, response)
}
// 直方图计算函数
func calculateHistogram(data []float64, bins int) []HistoBin {
minV, maxV := floats.Min(data), floats.Max(data)
binWidth := (maxV - minV) / float64(bins)
counts := make([]int, bins)
for _, v := range data {
idx := int((v - minV) / binWidth)
if idx == bins { // 处理最大值刚好等于maxV的情况
idx--
}
counts[idx]++
}
histogram := make([]HistoBin, bins)
for i := 0; i < bins; i++ {
start := minV + float64(i)*binWidth
end := minV + float64(i+1)*binWidth
histogram[i] = HistoBin{
BinStart: start,
BinEnd: end,
Count: counts[i],
}
}
return histogram
}
// 转换曲线数据格式
func convertToCurvePoints(x, y []float64) []CurvePoint {
points := make([]CurvePoint, len(x))
for i := range x {
points[i] = CurvePoint{
X: x[i],
Y: y[i],
}
}
return points
}
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
var startTime int64
if record.TestTime > 0 {
startTime = record.TestTime
} else {
startTime = record.StartTime
}
// 获取所有唯一的beltID
var beltIDs []uint
tx.Model(&models.HeartRate{}).Where("train_id = ?", record.TrainId).
Select("DISTINCT belt_id").Pluck("belt_id", &beltIDs)
// 对每个belt计算
for _, bid := range beltIDs {
// 计算平均心率
ranges := getTimeRanges(startTime)
averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
if err != nil {
return err
}
// 曲线拟合
x := []float64{2, 4, 6}
y := []float64{averages["2min"], averages["4min"], averages["6min"]}
a, b, _ := quadraticFit(x, y)
// 存储结果
analysis := models.BeltAnalysis{
TrainID: record.TrainId,
RunType: record.RunType,
BeltID: bid,
Avg2min: averages["2min"],
Avg4min: averages["4min"],
Avg6min: averages["6min"],
CurveParamA: a,
CurveParamB: b,
}
if err := tx.Create(&analysis).Error; err != nil {
return err
}
}
return nil
}
func collectCurveParams(tx *gorm.DB) ([]float64, error) {
var aValues []float64
// 查询所有记录的 CurveParamA 字段
err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &aValues).Error
if err != nil {
return nil, err
}
return aValues, nil
}
func calculateStats(data []float64) (mean, stddev float64) {
mean = stat.Mean(data, nil)
variance := stat.Variance(data, nil)
stddev = math.Sqrt(variance)
return
}
func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) {
normal := distuv.Normal{
Mu: mean,
Sigma: stddev,
}
minV := mean - 3*stddev // 从均值-3σ开始
maxV := mean + 3*stddev // 到均值+3σ结束
step := (maxV - minV) / float64(numPoints-1)
for i := 0; i < numPoints; i++ {
xi := minV + float64(i)*step
yi := normal.Prob(xi)
x = append(x, xi)
y = append(y, yi)
}
return
}
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
averages := make(map[string]float64)
for key, tr := range ranges {
var avg float64
// 使用GORM Raw SQL提高效率[6,10](@ref)
err := tx.Raw(`
SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复
FROM heart_rates
WHERE train_id = ?
AND belt_id = ?
AND time BETWEEN ? AND ?`,
trainID, beltID, tr.Start, tr.End,
).Scan(&avg).Error
if err != nil {
return nil, err
}
averages[key] = avg
}
return averages, nil
}
//func quadraticFit(x []float64, y []float64) (float64, error) {
// // 使用三点计算y=ax²+b的a值x=[2,4,6]对应分钟)
// if len(x) != 3 || len(y) != 3 {
// return 0, errors.New("需要三个点")
// }
// // 构造方程组矩阵(简化计算)
// a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
// return a, nil
//}
func quadraticFit(x []float64, y []float64) (float64, float64, error) {
// 校验输入长度
if len(x) != 3 || len(y) != 3 {
return 0, 0, errors.New("需要三个点")
}
// 计算各项累加值
var sumX4, sumX2, sumY, sumX2Y float64
for i := 0; i < 3; i++ {
xi := x[i]
xi2 := xi * xi
sumX4 += xi2 * xi2 // x^4累加
sumX2 += xi2 // x^2累加
sumY += y[i] // y累加
sumX2Y += xi2 * y[i] // x²y累加
}
// 计算行列式
determinant := sumX4*3 - sumX2*sumX2
if determinant == 0 {
return 0, 0, errors.New("无解,行列式为零")
}
// 计算系数 a 和 b
a := (sumX2Y*3 - sumY*sumX2) / determinant
b := (sumX4*sumY - sumX2*sumX2Y) / determinant
return a, b, nil
}
type TimeRange struct {
Start int64 // 毫秒时间戳起点
End int64 // 毫秒时间戳终点
}
func getTimeRanges(startTime int64) map[string]TimeRange {
// 计算相对于训练开始时间的窗口
return map[string]TimeRange{
"2min": {Start: startTime + 120000, End: startTime + 240000}, // 第2分钟120-240秒
"4min": {Start: startTime + 240000, End: startTime + 360000},
"6min": {Start: startTime + 360000, End: startTime + 480000},
}
}

View File

@ -4,7 +4,7 @@ services:
app:
build: .
ports:
- "8080:8180"
- "8180:8080"
depends_on:
db:
condition: service_healthy

7
go.mod
View File

@ -4,9 +4,10 @@ go 1.23.3
require (
github.com/gin-gonic/gin v1.10.0
github.com/golang-jwt/jwt/v4 v4.5.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/sajari/regression v1.0.1
github.com/spf13/viper v1.20.0
gonum.org/v1/gonum v0.16.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.25.12
)
@ -51,9 +52,9 @@ require (
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.36.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

14
go.sum
View File

@ -31,8 +31,6 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@ -77,6 +75,8 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
github.com/sajari/regression v1.0.1 h1:iTVc6ZACGCkoXC+8NdqH5tIreslDTT/bXxT6OmHR5PE=
github.com/sajari/regression v1.0.1/go.mod h1:NeG/XTW1lYfGY7YV/Z0nYDV/RGh3wxwd1yW46835flM=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
@ -114,14 +114,16 @@ golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

11
main.go
View File

@ -15,7 +15,16 @@ func main() {
config.DB.Debug()
// 自动迁移模型
config.DB.AutoMigrate(&models.TrainRecord{}, &models.TrainingData{}, &models.Belt{}, &models.HeartRate{})
config.DB.AutoMigrate(&models.TrainRecord{},
&models.TrainingData{},
&models.Belt{},
&models.HeartRate{},
&models.BeltAnalysis{},
&models.StepTrainRecord{},
&models.StepHeartRate{},
&models.StepStrideFreq{},
&models.RegressionResult{},
)
// 启动服务
r := routes.SetupRouter()

View File

@ -4,10 +4,12 @@ import (
"compress/gzip"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)
func GzipMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 1. 处理请求解压
if c.Request.Header.Get("Content-Encoding") == "gzip" {
gzReader, err := gzip.NewReader(c.Request.Body)
if err != nil {
@ -17,6 +19,39 @@ func GzipMiddleware() gin.HandlerFunc {
defer gzReader.Close()
c.Request.Body = gzReader
}
// 2. 设置响应压缩支持
if strings.Contains(c.Request.Header.Get("Accept-Encoding"), "gzip") {
// 创建gzip writer
gzWriter := gzip.NewWriter(c.Writer)
defer gzWriter.Close()
// 替换原始writer为压缩writer
originalWriter := c.Writer
c.Writer = &gzipResponseWriter{
ResponseWriter: originalWriter,
gzWriter: gzWriter,
}
// 设置响应头
c.Header("Content-Encoding", "gzip")
c.Header("Vary", "Accept-Encoding")
}
c.Next()
}
}
// 自定义ResponseWriter实现gzip压缩
type gzipResponseWriter struct {
gin.ResponseWriter
gzWriter *gzip.Writer
}
func (w *gzipResponseWriter) Write(data []byte) (int, error) {
return w.gzWriter.Write(data)
}
func (w *gzipResponseWriter) WriteString(s string) (int, error) {
return w.gzWriter.Write([]byte(s))
}

1
models/pageination.go Normal file
View File

@ -0,0 +1 @@
package models

69
models/step_train.go Normal file
View File

@ -0,0 +1,69 @@
package models
import "gorm.io/gorm"
type StepStrideFreq struct {
gorm.Model
TrainId uint `gorm:"column:train_id; index" json:"trainId"` // 外键关联训练记录[4](@ref)
Time int64 `gorm:"type:bigint" json:"time"` // 保持与前端一致的毫秒时间戳[3](@ref)
Value int `gorm:"type:int" json:"value"`
Count int `gorm:"type:int" json:"count"`
PredictValue int `gorm:"type:int" json:"predictValue"`
Identifier string `gorm:"uniqueIndex;type:varchar(255)" json:"identifier"`
}
// 对应Flutter的HeartRate结构
type StepHeartRate struct {
gorm.Model
TrainId uint `gorm:"column:train_id; index" json:"trainId"` // 外键关联训练记录[4](@ref)
Time int64 `gorm:"type:bigint" json:"time"` // 保持与前端一致的毫秒时间戳[3](@ref)
Value int `gorm:"type:int" json:"value"`
HeartRateType int `gorm:"type:int" json:"predictValue"`
Identifier string `gorm:"uniqueIndex;type:varchar(255)" json:"identifier"`
}
// 对应Flutter的TrainRecord结构
type StepTrainRecord struct {
gorm.Model
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
Name string `gorm:"size:100" json:"name"`
RunType string `gorm:"size:100" json:"runType"`
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
Duration int `gorm:"type:int" json:"duration"` // 持续时间(秒)
DeadZone int `gorm:"type:int" json:"deadZone"`
Evaluation string `gorm:"size:50" json:"evaluation"`
HeartRates []StepHeartRate `gorm:"foreignKey:TrainId;references:TrainId" json:"heartRates"`
StrideFreqs []StepStrideFreq `gorm:"foreignKey:TrainId;references:TrainId" json:"strideFreqs"`
}
type RegressionType int
const (
LinearRegression RegressionType = iota + 1
LogarithmicRegression
QuadraticRegression
)
type RegressionResult struct {
gorm.Model
RegressionType RegressionType `gorm:"column:regression_type;index" json:"regressionType"` // 训练记录ID
TrainId uint `gorm:"column:train_id;index" json:"trainId"` // 训练记录ID
Equation string `gorm:"type:text" json:"equation"` // 回归方程
// 线性回归系数
Slope *float64 `gorm:"column:slope" json:"slope"`
Intercept *float64 `gorm:"column:intercept" json:"intercept"`
// 对数回归系数
LogA *float64 `gorm:"column:log_a" json:"logA"`
LogB *float64 `gorm:"column:log_b" json:"logB"`
// 二次回归系数
QuadraticA *float64 `gorm:"column:quadratic_a" json:"quadraticA"`
QuadraticB *float64 `gorm:"column:quadratic_b" json:"quadraticB"`
QuadraticC *float64 `gorm:"column:quadratic_c" json:"quadraticC"`
RSquared *float64 `gorm:"column:r_squared" json:"rSquared"` // R平方值
}

View File

@ -18,6 +18,21 @@ type Belt struct {
Name string `gorm:"size:100" json:"name"`
}
// 分析结果存储实体
type BeltAnalysis struct {
gorm.Model
TrainID uint `gorm:"index;not null"` // 关联训练记录
BeltID uint `gorm:"index;not null"` // 腰带唯一标识
RunType string `gorm:"size:100" json:"RunType"`
Avg2min float64 `gorm:"type:double precision"` // 第2分钟平均心率
Avg4min float64 `gorm:"type:double precision"` // 第4分钟平均心率
Avg6min float64 `gorm:"type:double precision"` // 第6分钟平均心率
CurveParamA float64 `gorm:"type:double precision"` // 拟合参数a值
CurveParamB float64 `gorm:"type:double precision"` // 拟合参数a值
}
// 中间计算结构(无需持久化)
// 对应Flutter的HeartRate结构
type HeartRate struct {
gorm.Model
@ -35,8 +50,10 @@ type TrainRecord struct {
gorm.Model
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
TestTime int64 `gorm:"type:bigint" json:"testTime"` // 开始时间戳
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
Name string `gorm:"size:100" json:"name"`
RunType string `gorm:"size:100" json:"RunType"`
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
Duration int `gorm:"type:int" json:"duration"` // 持续时间(秒)
PeopleNum int `gorm:"type:int" json:"peopleNum"`

View File

@ -12,12 +12,22 @@ func SetupRouter() *gin.Engine {
r := gin.Default()
r.Use(middleware.GzipMiddleware())
trainingController := controllers.NewTrainingController()
stepTrainController := controllers.NewStepTrainingController()
v1 := r.Group("/api/v1")
{
records := v1.Group("/train-records").Use(middleware.AuthMiddleware())
records := v1.Group("/train-records") //.Use(middleware.AuthMiddleware())
{
records.POST("", trainingController.CreateTrainingRecord)
records.GET("/analysis", trainingController.HandleCurveAnalysis)
// 可扩展其他路由GET, PUT, DELETE等
}
steps := v1.Group("/step") //.Use(middleware.AuthMiddleware())
{
steps.POST("", stepTrainController.CreateTrainingRecord)
steps.GET("train-records", stepTrainController.GetTrainingRecords)
steps.GET("train-data/:trainId", stepTrainController.GetTrainingRecordByTrainId)
steps.GET("train-rank/:trainId", stepTrainController.GetTrainingRank)
// 可扩展其他路由GET, PUT, DELETE等
}
auth := v1.Group("/auth")

2
run.sh Normal file
View File

@ -0,0 +1,2 @@
#!/bin/bash
docker-compose up --build -d