Files
hr_data_analyzer/controllers/train.go
T
2026-05-04 16:20:46 +08:00

542 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controllers
import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"hr_receiver/config"
"hr_receiver/models"
"math"
"net/http"
)
// analyzeRunTypes lists the run-type labels for which heart-rate curve
// analysis is performed after a training record is saved.
var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"} // replace with your own values

// contains reports whether s is exactly equal to one of analyzeRunTypes.
func contains(s string) bool {
	for i := 0; i < len(analyzeRunTypes); i++ {
		if analyzeRunTypes[i] == s {
			return true
		}
	}
	return false
}
// TrainingController groups the HTTP handlers for training records and holds
// the database handle they share.
type TrainingController struct {
	// DB is the GORM handle used by the handlers on this controller.
	DB *gorm.DB
}
// NewTrainingController returns a TrainingController whose DB field is set to
// the package-level config.DB handle.
func NewTrainingController() *TrainingController {
	controller := new(TrainingController)
	controller.DB = config.DB
	return controller
}
// CreateTrainingRecord binds a models.TrainRecord from the JSON request body
// and stores it, together with its heart-rate samples, inside one database
// transaction. Existing rows are updated in place: the record is matched on
// train_id and every heart-rate sample on identifier. When the record's run
// type is one of analyzeRunTypes, heartRateAnalyze runs in the same
// transaction and its error, if any, is returned from the transaction
// callback.
//
// Responses: 400 with the bind error, 500 with the transaction error, or 201
// with the saved record's TrainId.
//
// @Summary 创建训练记录
// @Description 接收并保存训练记录及心率数据支持重复上传按train_id去重更新
// @Tags 训练管理
// @Accept json
// @Produce json
// @Param record body SwagAPIResponse true "训练记录数据"
// @Success 201 {object} SwagAPIResponse "保存成功"
// @Failure 400 {object} SwagAPIResponse "请求参数错误"
// @Router /train-records [post]
func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.TrainRecord

	// Bind and validate the JSON payload.
	if bindErr := c.ShouldBindJSON(&record); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	// persist is the transaction body: main record, samples, optional analysis.
	persist := func(tx *gorm.DB) error {
		// Upsert the main record; train_id is the conflict column.
		recordUpsert := clause.OnConflict{
			Columns: []clause.Column{{Name: "train_id"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"max_heart_rate": record.MaxHeartRate,
				"start_time":     record.StartTime,
				"end_time":       record.EndTime,
				"duration":       record.Duration,
				"people_num":     record.PeopleNum,
				"name":           record.Name,
				"evaluation":     record.Evaluation,
			}),
		}
		if err := tx.Clauses(recordUpsert).Omit("HeartRates", "belts").Create(&record).Error; err != nil {
			return err
		}

		// Upsert the associated heart-rate samples one by one; identifier is
		// the conflict column.
		for i := range record.HeartRates {
			sampleUpsert := clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"value": record.HeartRates[i].Value,
					"time":  record.HeartRates[i].Time,
				}),
			}
			if err := tx.Clauses(sampleUpsert).Create(&record.HeartRates[i]).Error; err != nil {
				return err
			}
		}

		// Only the configured run types get a curve analysis.
		if !contains(record.RunType) {
			return nil
		}
		return tc.heartRateAnalyze(tx, record)
	}

	if txErr := tc.DB.Transaction(persist); txErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": txErr.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}
// trainingSessionRequest is the JSON body accepted by UploadTrainingSession.
// The whole struct is also re-marshalled and stored as the session's raw
// payload, so the field order below is part of the stored format.
type trainingSessionRequest struct {
	Tid int `json:"tid"`
	// Time is formatted as a decimal string to form the train/test id and is
	// stored as the session's started-at value.
	Time     int64 `json:"time"`
	TestTime int64 `json:"testTime"`
	// EndTime is stored as the session's ended-at value on a stop request.
	EndTime      int64  `json:"endTime"`
	Name         string `json:"name"`
	RunType      string `json:"runType"`
	Gender       string `json:"gender"`
	Age          int    `json:"age"`
	MaxHeartRate int    `json:"maxHeartRate"`
	Duration     int    `json:"duration"`
	PeopleNum    int    `json:"peopleNum"`
	Evaluation   string `json:"evaluation"`
	AiResult     string `json:"aiResult"`
	// IsStart selects the start branch (true) or the stop branch (false).
	IsStart bool `json:"isStart"`
	// RegionID is part of the generated session identifier.
	RegionID uint32 `json:"regionId"`
	AppName  string `json:"appName"`
}
// UploadTrainingSession records the start or the end of a training session.
//
// The request's Time value, formatted as a decimal string, is used as the
// train/test id, and a session identifier is derived from the flavor type,
// the region id and that train id.
//
//   - IsStart == true: the session row is upserted on identifier with event
//     type "start_test".
//   - IsStart == false: the session is looked up by train_id. If a row exists
//     it is updated to "stop_test"; if no row exists a complete "stop_test"
//     row is upserted on identifier. Any other lookup error yields a 500
//     instead of being mistaken for "not found".
//
// Responses are written through writeError / writeSuccess.
//
// @Summary 上传训练会话
// @Description 上传训练开始/结束会话用于MQTT训练会话追踪
// @Tags 训练管理
// @Accept json
// @Produce json
// @Param session body trainingSessionRequest true "训练会话数据"
// @Success 200 {object} SwagAPIResponse "操作成功"
// @Failure 400 {object} SwagAPIResponse "请求参数错误"
// @Router /train-records/session [post]
func (tc *TrainingController) UploadTrainingSession(c *gin.Context) {
	var req trainingSessionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		writeError(c, http.StatusBadRequest, err.Error())
		return
	}

	trainID := fmt.Sprintf("%d", req.Time)
	now := time.Now().UnixMilli()
	flavorType := "heartrate"
	identifier := schema.NamingStrategy{}.IndexName(
		"mqtt_training_session",
		fmt.Sprintf("%s_%d_%s", flavorType, req.RegionID, trainID),
	)

	// Marshal cannot fail here: the request holds only integers, strings and
	// a bool, so the error is deliberately discarded.
	rawPayload, _ := json.Marshal(req)

	// newRecord builds the row shared by the start path and the
	// create-on-stop path; only the event type (and EndedAt) differ.
	newRecord := func(eventType string) *models.MqttTrainingSessionRecord {
		return &models.MqttTrainingSessionRecord{
			Identifier:  identifier,
			TestID:      trainID,
			EventType:   eventType,
			RegionID:    req.RegionID,
			FlavorType:  flavorType,
			RawFlavor:   "hr",
			AppName:     req.AppName,
			TrainId:     trainID,
			PeopleNum:   req.PeopleNum,
			StartedAt:   &req.Time,
			PublishedAt: now,
			ReceivedAt:  now,
			RawPayload:  string(rawPayload),
		}
	}

	if req.IsStart {
		record := newRecord("start_test")
		err := tc.DB.Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "identifier"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"started_at": req.Time,
				"train_id":   trainID,
				"updated_at": time.Now(),
			}),
		}).Create(record).Error
		if err != nil {
			writeError(c, http.StatusInternalServerError, "failed to save session")
			return
		}
		writeSuccess(c, http.StatusOK, "session start registered", nil)
		return
	}

	// NOTE(review): the lookup is by train_id only, while the upserts conflict
	// on identifier (which also encodes the region) — confirm that train_id
	// alone is unique enough for this lookup.
	var existing models.MqttTrainingSessionRecord
	lookupErr := tc.DB.Where("train_id = ?", trainID).First(&existing).Error

	switch {
	case lookupErr == nil:
		// A session row exists: mark it as stopped.
		updates := map[string]interface{}{
			"event_type":  "stop_test",
			"ended_at":    req.EndTime,
			"people_num":  req.PeopleNum,
			"received_at": now,
			"raw_payload": string(rawPayload),
			"updated_at":  time.Now(),
		}
		if err := tc.DB.Model(&existing).Updates(updates).Error; err != nil {
			writeError(c, http.StatusInternalServerError, "failed to update session")
			return
		}
		writeSuccess(c, http.StatusOK, "session updated", nil)

	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		// No start was registered: create the finished session in one go.
		record := newRecord("stop_test")
		record.EndedAt = &req.EndTime
		err := tc.DB.Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "identifier"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"ended_at":   req.EndTime,
				"people_num": record.PeopleNum,
				"event_type": record.EventType,
				"updated_at": time.Now(),
			}),
		}).Create(record).Error
		if err != nil {
			writeError(c, http.StatusInternalServerError, "failed to create session")
			return
		}
		writeSuccess(c, http.StatusOK, "session created", nil)

	default:
		// A genuine database failure must not fall through to the insert path.
		writeError(c, http.StatusInternalServerError, "failed to look up session")
	}
}
// cloudLessonPlanItem is one entry of the list returned by
// ListCloudLessonPlans; its fields are copied from a stored models.AppFile.
type cloudLessonPlanItem struct {
	// ID is the stored file's ID.
	ID uint `json:"id"`
	// OriginalFilename is the stored file's OriginalFilename.
	OriginalFilename string `json:"originalFilename"`
	// FileSize is the stored file's FileSize.
	FileSize int64 `json:"fileSize"`
	// UploaderName is the stored file's UploaderName.
	UploaderName string `json:"uploaderName"`
}
// ListCloudLessonPlans loads every models.AppFile whose file_type equals
// models.AppFileTypeLessonPlan, ordered by created_at descending, and returns
// them as a list of cloudLessonPlanItem via writeSuccess. A query failure is
// reported through writeError with status 500.
//
// @Summary 获取云端教案列表
// @Description 获取所有云端教案文件列表
// @Tags 训练管理
// @Produce json
// @Success 200 {object} SwagAPIResponse "查询成功"
// @Router /train-records/cloud-files [get]
func (tc *TrainingController) ListCloudLessonPlans(c *gin.Context) {
	var files []models.AppFile
	query := tc.DB.Where("file_type = ?", models.AppFileTypeLessonPlan).Order("created_at DESC")
	if err := query.Find(&files).Error; err != nil {
		writeError(c, http.StatusInternalServerError, "failed to list lesson plans")
		return
	}

	// Sized up front so an empty result is still a non-nil slice.
	plans := make([]cloudLessonPlanItem, len(files))
	for i := range files {
		plans[i] = cloudLessonPlanItem{
			ID:               files[i].ID,
			OriginalFilename: files[i].OriginalFilename,
			FileSize:         files[i].FileSize,
			UploaderName:     files[i].UploaderName,
		}
	}
	writeSuccess(c, http.StatusOK, "query success", plans)
}
// AnalysisResponse is the JSON body returned by HandleCurveAnalysis on
// success.
type AnalysisResponse struct {
	// Status is the status label of the response.
	Status string `json:"status"`
	// Message carries additional information.
	Message string `json:"message"`
	// Data holds the computed statistics and chart series.
	Data struct {
		Mean      float64      `json:"mean"`      // mean
		StdDev    float64      `json:"stdDev"`    // standard deviation
		Histogram []HistoBin   `json:"histogram"` // histogram data
		Curve     []CurvePoint `json:"curve"`     // normal-curve data
	} `json:"data"`
}
// HistoBin is one histogram bucket: its value interval and how many samples
// were counted in it.
type HistoBin struct {
	// BinStart is the start value of the interval.
	BinStart float64 `json:"binStart"`
	// BinEnd is the end value of the interval.
	BinEnd float64 `json:"binEnd"`
	// Count is the number of samples counted in the interval.
	Count int `json:"count"`
}
// CurvePoint is a single (x, y) point of a plotted curve.
type CurvePoint struct {
	// X is the x coordinate.
	X float64 `json:"x"`
	// Y is the y coordinate.
	Y float64 `json:"y"`
}
// HandleCurveAnalysis loads every stored curve parameter A, and when at least
// ten values are available responds with their mean, standard deviation, a
// 20-bin histogram and a 100-point normal curve built from that mean and
// standard deviation.
//
// Responses: 500 when the query fails, 400 when fewer than ten values exist,
// 200 with an AnalysisResponse otherwise.
//
// @Summary 心率曲线分析
// @Description 对历史心率数据进行统计分析和正态分布曲线拟合
// @Tags 训练管理
// @Produce json
// @Success 200 {object} SwagAPIResponse "分析成功"
// @Failure 400 {object} SwagAPIResponse "数据量不足"
// @Router /train-records/analysis [get]
func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) {
	const (
		minSamples    = 10  // fewest values accepted for analysis
		histogramBins = 20  // number of histogram buckets
		curvePoints   = 100 // number of points on the normal curve
	)

	// 1. Load the historical values.
	aValues, err := collectCurveParams(tc.DB)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"status":  "error",
			"message": "数据查询失败: " + err.Error(),
		})
		return
	}

	// 2. Refuse to analyse too small a sample.
	if len(aValues) < minSamples {
		c.JSON(http.StatusBadRequest, gin.H{
			"status":  "fail",
			"message": "数据量不足至少需要10个样本",
		})
		return
	}

	// 3-5. Statistics, histogram and normal curve.
	mean, stddev := calculateStats(aValues)
	histogram := calculateHistogram(aValues, histogramBins)
	curveX, curveY := generateNormalCurve(mean, stddev, curvePoints)

	// 6. Assemble the response field by field.
	var response AnalysisResponse
	response.Status = "success"
	response.Message = "分析完成"
	response.Data.Mean = mean
	response.Data.StdDev = stddev
	response.Data.Histogram = histogram
	response.Data.Curve = convertToCurvePoints(curveX, curveY)

	c.JSON(http.StatusOK, response)
}
// calculateHistogram splits the range [min(data), max(data)] into bins
// equal-width intervals and counts how many values fall into each.
//
// It returns one HistoBin per interval, in ascending order. The maximum value
// is counted in the last bin. When data is empty or bins is not positive an
// empty (non-nil) slice is returned. When every value is identical the bin
// width is zero; all values are then counted in the first bin and every bin
// has BinStart == BinEnd.
func calculateHistogram(data []float64, bins int) []HistoBin {
	// floats.Min/Max panic on an empty slice and make panics on a negative
	// length, so reject both up front.
	if len(data) == 0 || bins <= 0 {
		return []HistoBin{}
	}

	minV, maxV := floats.Min(data), floats.Max(data)
	binWidth := (maxV - minV) / float64(bins)

	counts := make([]int, bins)
	for _, v := range data {
		idx := 0
		// With a zero width the division would yield NaN, whose conversion to
		// int is not a usable index; such values stay in bin 0.
		if binWidth > 0 {
			idx = int((v - minV) / binWidth)
		}
		// Clamp: the maximum lands exactly on index bins, and rounding or
		// non-finite input must never index outside counts.
		if idx >= bins {
			idx = bins - 1
		}
		if idx < 0 {
			idx = 0
		}
		counts[idx]++
	}

	histogram := make([]HistoBin, bins)
	for i := range histogram {
		histogram[i] = HistoBin{
			BinStart: minV + float64(i)*binWidth,
			BinEnd:   minV + float64(i+1)*binWidth,
			Count:    counts[i],
		}
	}
	return histogram
}
// convertToCurvePoints pairs x[i] with y[i] into a CurvePoint for every index
// of x. y must be at least as long as x. The result is never nil.
func convertToCurvePoints(x, y []float64) []CurvePoint {
	points := make([]CurvePoint, 0, len(x))
	for i, xi := range x {
		points = append(points, CurvePoint{X: xi, Y: y[i]})
	}
	return points
}
// heartRateAnalyze computes and stores one models.BeltAnalysis per belt of
// the given training record, using tx for every query.
//
// The analysis windows are measured from record.TestTime when it is positive
// and from record.StartTime otherwise. For each distinct belt_id found in the
// record's heart-rate rows, the average heart rate of the "2min", "4min" and
// "6min" windows is fitted to y = a*x^2 + b at x = 2, 4, 6, and the averages
// together with a and b are inserted as a new row.
//
// It returns the first error raised by a query, the fit or an insert.
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
	startTime := record.StartTime
	if record.TestTime > 0 {
		startTime = record.TestTime
	}

	// Collect every distinct belt that reported samples for this training.
	var beltIDs []uint
	if err := tx.Model(&models.HeartRate{}).Where("train_id = ?", record.TrainId).
		Select("DISTINCT belt_id").Pluck("belt_id", &beltIDs).Error; err != nil {
		return fmt.Errorf("listing belts of train %v: %w", record.TrainId, err)
	}

	// The windows and the x coordinates are the same for every belt.
	ranges := getTimeRanges(startTime)
	x := []float64{2, 4, 6}

	for _, bid := range beltIDs {
		// Average heart rate per window.
		averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
		if err != nil {
			return err
		}

		// Curve fit through the three window averages.
		y := []float64{averages["2min"], averages["4min"], averages["6min"]}
		a, b, err := quadraticFit(x, y)
		if err != nil {
			return fmt.Errorf("fitting curve for belt %d: %w", bid, err)
		}

		// NOTE(review): this always inserts; a re-uploaded training record
		// looks like it would add a second analysis row per belt — confirm
		// whether BeltAnalysis has a uniqueness constraint.
		analysis := models.BeltAnalysis{
			TrainID:     record.TrainId,
			RunType:     record.RunType,
			BeltID:      bid,
			Avg2min:     averages["2min"],
			Avg4min:     averages["4min"],
			Avg6min:     averages["6min"],
			CurveParamA: a,
			CurveParamB: b,
		}
		if err := tx.Create(&analysis).Error; err != nil {
			return err
		}
	}
	return nil
}
// collectCurveParams returns the curve_param_a column of every
// models.BeltAnalysis row, or nil and the query error.
func collectCurveParams(tx *gorm.DB) ([]float64, error) {
	var params []float64
	if err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &params).Error; err != nil {
		return nil, err
	}
	return params, nil
}
// calculateStats returns the unweighted mean of data and the square root of
// its variance, both computed with gonum's stat package.
func calculateStats(data []float64) (mean, stddev float64) {
	return stat.Mean(data, nil), math.Sqrt(stat.Variance(data, nil))
}
// generateNormalCurve samples the density of the normal distribution with the
// given mean and standard deviation at numPoints evenly spaced positions from
// mean-3*stddev to mean+3*stddev.
//
// x holds the sample positions and y the density at each position; both have
// length numPoints. With numPoints == 1 the single sample is taken at mean.
// When numPoints is not positive, stddev is not strictly positive, or mean or
// stddev is NaN or infinite, no curve can be drawn and both results are nil.
func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) {
	finite := func(v float64) bool { return !math.IsNaN(v) && !math.IsInf(v, 0) }
	// A degenerate distribution would only produce NaN/Inf points, which
	// encoding/json refuses to encode.
	if numPoints <= 0 || !finite(mean) || !finite(stddev) || stddev <= 0 {
		return nil, nil
	}

	normal := distuv.Normal{
		Mu:    mean,
		Sigma: stddev,
	}
	x = make([]float64, 0, numPoints)
	y = make([]float64, 0, numPoints)

	// A single point has no spacing; dividing by numPoints-1 would be 0/0.
	if numPoints == 1 {
		return append(x, mean), append(y, normal.Prob(mean))
	}

	minV := mean - 3*stddev // start at mean - 3 sigma
	maxV := mean + 3*stddev // end at mean + 3 sigma
	step := (maxV - minV) / float64(numPoints-1)
	for i := 0; i < numPoints; i++ {
		xi := minV + float64(i)*step
		x = append(x, xi)
		y = append(y, normal.Prob(xi))
	}
	return x, y
}
// calculateAverages returns, for every entry of ranges, the average value of
// the heart_rates rows of the given training and belt whose time lies within
// that range (BETWEEN, so both ends are inclusive). The result is keyed like
// ranges. A range without rows yields 0 because the query wraps AVG in
// COALESCE. The first query error aborts the computation.
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
	// One aggregate query per window, run as raw SQL.
	const avgQuery = `
SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复
FROM heart_rates
WHERE train_id = ?
AND belt_id = ?
AND time BETWEEN ? AND ?`

	result := make(map[string]float64, len(ranges))
	for label, window := range ranges {
		var avg float64
		if err := tx.Raw(avgQuery, trainID, beltID, window.Start, window.End).Scan(&avg).Error; err != nil {
			return nil, err
		}
		result[label] = avg
	}
	return result, nil
}
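// Superseded three-point version, kept for reference only; the active
// quadraticFit is defined below.
//func quadraticFit(x []float64, y []float64) (float64, error) {
//	// Compute a of y = ax² + b from three points (x = [2, 4, 6], in minutes).
//	if len(x) != 3 || len(y) != 3 {
//		return 0, errors.New("需要三个点")
//	}
//	// Closed-form expression from the simplified system of equations.
//	a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
//	return a, nil
//}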
// quadraticFit fits y = a*x^2 + b to exactly three points by least squares
// and returns a and b.
//
// It returns an error when x or y does not hold exactly three values, or when
// the determinant of the normal equations is zero (for example when all x
// values have the same square).
func quadraticFit(x []float64, y []float64) (float64, float64, error) {
	const points = 3

	// Both inputs must hold exactly three values.
	if len(x) != points || len(y) != points {
		return 0, 0, errors.New("需要三个点")
	}

	// Accumulate the sums that make up the normal equations.
	var sumX4, sumX2, sumY, sumX2Y float64
	for i, xi := range x {
		sq := xi * xi
		sumX4 += sq * sq    // sum of x^4
		sumX2 += sq         // sum of x^2
		sumY += y[i]        // sum of y
		sumX2Y += sq * y[i] // sum of x^2 * y
	}

	// Solve the 2x2 system with Cramer's rule.
	det := sumX4*points - sumX2*sumX2
	if det == 0 {
		return 0, 0, errors.New("无解,行列式为零")
	}

	a := (sumX2Y*points - sumY*sumX2) / det
	b := (sumX4*sumY - sumX2*sumX2Y) / det
	return a, b, nil
}
// TimeRange is a window between two millisecond timestamps.
type TimeRange struct {
	Start int64 // window start, millisecond timestamp
	End   int64 // window end, millisecond timestamp
}

// getTimeRanges returns three consecutive 120-second windows measured from
// startTime (a millisecond timestamp), keyed by the minute at which each one
// begins: "2min" covers +120s..+240s, "4min" covers +240s..+360s and "6min"
// covers +360s..+480s.
func getTimeRanges(startTime int64) map[string]TimeRange {
	const windowMs int64 = 120000 // length of one window in milliseconds

	labels := []string{"2min", "4min", "6min"}
	ranges := make(map[string]TimeRange, len(labels))
	for i, label := range labels {
		from := startTime + int64(i+1)*windowMs
		ranges[label] = TimeRange{Start: from, End: from + windowMs}
	}
	return ranges
}