Files
hr_data_analyzer/controllers/train.go

331 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controllers
import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
import (
"errors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"hr_receiver/config"
"hr_receiver/models"
"math"
"net/http"
)
// analyzeRunTypes lists the run types whose uploads trigger per-belt heart-rate
// analysis in CreateTrainingRecord (replace with your concrete values).
var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"}

// contains reports whether s exactly matches one of analyzeRunTypes.
func contains(s string) bool {
	found := false
	for i := 0; i < len(analyzeRunTypes) && !found; i++ {
		found = analyzeRunTypes[i] == s
	}
	return found
}
// TrainingController serves the training-record and curve-analysis endpoints
// using a GORM database handle.
type TrainingController struct {
	DB *gorm.DB
}

// NewTrainingController returns a controller wired to the package-level
// config.DB handle.
func NewTrainingController() *TrainingController {
	tc := new(TrainingController)
	tc.DB = config.DB
	return tc
}
// CreateTrainingRecord receives a training record as JSON and stores it,
// together with its heart-rate samples, inside a single transaction.
//
// The main record is upserted on train_id, each heart-rate sample is upserted
// on identifier, and when the record's RunType is one of analyzeRunTypes the
// per-belt analysis is computed within the same transaction. Responds 400 on a
// body that fails to bind, 500 on any storage error, and 201 with the train ID
// on success.
func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.TrainRecord
	if bindErr := c.ShouldBindJSON(&record); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	txErr := tc.DB.Transaction(func(tx *gorm.DB) error {
		return tc.saveTrainingRecord(tx, &record)
	})
	if txErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": txErr.Error()})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}

// saveTrainingRecord performs every write of CreateTrainingRecord through tx;
// a non-nil error makes the caller's transaction roll back.
func (tc *TrainingController) saveTrainingRecord(tx *gorm.DB, record *models.TrainRecord) error {
	// When train_id already exists, overwrite these columns instead of failing.
	recordUpsert := clause.OnConflict{
		Columns: []clause.Column{{Name: "train_id"}},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"max_heart_rate": record.MaxHeartRate,
			"start_time":     record.StartTime,
			"end_time":       record.EndTime,
			"duration":       record.Duration,
			"people_num":     record.PeopleNum,
			"name":           record.Name,
			"evaluation":     record.Evaluation,
		}),
	}
	// HeartRates are written one by one below, so they (and belts) are omitted
	// from this insert.
	if err := tx.Clauses(recordUpsert).Omit("HeartRates", "belts").Create(record).Error; err != nil {
		return err
	}

	for i := range record.HeartRates {
		sample := &record.HeartRates[i]
		// When identifier already exists, refresh the stored value and time.
		sampleUpsert := clause.OnConflict{
			Columns:   []clause.Column{{Name: "identifier"}},
			DoUpdates: clause.Assignments(map[string]interface{}{"value": sample.Value, "time": sample.Time}),
		}
		if err := tx.Clauses(sampleUpsert).Create(sample).Error; err != nil {
			return err
		}
	}

	if !contains(record.RunType) {
		return nil
	}
	return tc.heartRateAnalyze(tx, *record)
}
// AnalysisResponse is the JSON envelope produced by the curve-analysis endpoint.
type AnalysisResponse struct {
	Status  string            `json:"status"`  // status keyword, e.g. "success"
	Message string            `json:"message"` // human-readable detail
	Data    CurveAnalysisData `json:"data"`    // analysis payload
}

// CurveAnalysisData is the payload of AnalysisResponse: summary statistics of
// the analyzed values plus chart-ready histogram bins and curve points.
//
// It is a named type so callers can build the payload without restating the
// struct definition. Its field names, types and tags are identical to the
// anonymous struct previously used for AnalysisResponse.Data, so existing
// anonymous-struct literals of that shape remain assignable to Data.
type CurveAnalysisData struct {
	Mean      float64      `json:"mean"`      // mean
	StdDev    float64      `json:"stdDev"`    // standard deviation
	Histogram []HistoBin   `json:"histogram"` // histogram bins
	Curve     []CurvePoint `json:"curve"`     // normal-curve points
}

// HistoBin is one histogram bin.
type HistoBin struct {
	BinStart float64 `json:"binStart"` // lower edge of the bin
	BinEnd   float64 `json:"binEnd"`   // upper edge of the bin
	Count    int     `json:"count"`    // number of values counted in the bin
}

// CurvePoint is one (x, y) point of a plotted curve.
type CurvePoint struct {
	X float64 `json:"x"` // x coordinate
	Y float64 `json:"y"` // y coordinate
}
// HandleCurveAnalysis summarizes the distribution of every stored curve_param_a
// value: its mean and standard deviation, a 20-bin histogram, and a 100-point
// normal curve with that mean and deviation.
//
// Responds 500 if the query fails, 400 if fewer than 10 samples exist, and 200
// with an AnalysisResponse otherwise.
func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) {
	samples, err := collectCurveParams(tc.DB)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"status":  "error",
			"message": "数据查询失败: " + err.Error(),
		})
		return
	}

	// Require at least 10 samples before describing a distribution.
	if len(samples) < 10 {
		c.JSON(http.StatusBadRequest, gin.H{
			"status":  "fail",
			"message": "数据量不足至少需要10个样本",
		})
		return
	}

	var response AnalysisResponse
	response.Status = "success"
	response.Message = "分析完成"
	response.Data.Mean, response.Data.StdDev = calculateStats(samples)
	response.Data.Histogram = calculateHistogram(samples, 20)
	curveX, curveY := generateNormalCurve(response.Data.Mean, response.Data.StdDev, 100)
	response.Data.Curve = convertToCurvePoints(curveX, curveY)

	c.JSON(http.StatusOK, response)
}
// calculateHistogram splits [min(data), max(data)] into `bins` equal-width
// bins and counts the values in each. Bins are half-open [BinStart, BinEnd),
// except the last one, which also includes max(data).
//
// Edge cases:
//   - empty data or bins <= 0 returns nil;
//   - when all values are identical the range has zero width, so one bin
//     [v, v] holding every value is returned (dividing by a zero width would
//     yield NaN bin indexes and panic).
func calculateHistogram(data []float64, bins int) []HistoBin {
	if len(data) == 0 || bins <= 0 {
		return nil
	}
	minV, maxV := floats.Min(data), floats.Max(data)
	if maxV == minV {
		return []HistoBin{{BinStart: minV, BinEnd: maxV, Count: len(data)}}
	}
	binWidth := (maxV - minV) / float64(bins)

	counts := make([]int, bins)
	for _, v := range data {
		idx := int((v - minV) / binWidth)
		// maxV maps to index `bins`, and rounding can push values near the top
		// edge there too; fold them into the last bin.
		if idx >= bins {
			idx = bins - 1
		}
		counts[idx]++
	}

	histogram := make([]HistoBin, bins)
	for i := range histogram {
		histogram[i] = HistoBin{
			BinStart: minV + float64(i)*binWidth,
			BinEnd:   minV + float64(i+1)*binWidth,
			Count:    counts[i],
		}
	}
	// Pin the final edge to the exact maximum to avoid floating-point drift.
	histogram[bins-1].BinEnd = maxV
	return histogram
}
// convertToCurvePoints zips parallel x and y coordinate slices into
// CurvePoints. If the slices differ in length, the surplus tail of the longer
// one is ignored rather than indexing past the end of the shorter one. The
// result is never nil, so it encodes to JSON as [] when empty.
func convertToCurvePoints(x, y []float64) []CurvePoint {
	n := len(x)
	if len(y) < n {
		n = len(y)
	}
	points := make([]CurvePoint, n)
	for i := range points {
		points[i] = CurvePoint{X: x[i], Y: y[i]}
	}
	return points
}
// heartRateAnalyze runs the per-belt analysis for record using tx.
//
// For every distinct belt_id among the heart_rates rows of record.TrainId, it:
//   - averages the heart rate in the windows from getTimeRanges(record.StartTime);
//   - fits the quadratic coefficient a through (2, avg2), (4, avg4), (6, avg6);
//   - inserts one BeltAnalysis row.
//
// Any query, fit or insert error is returned so the caller's transaction can
// roll back.
//
// NOTE(review): rows are inserted with Create, so re-uploading the same
// train_id adds another BeltAnalysis row per belt unless a unique constraint
// rejects it — confirm whether older rows should be replaced.
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
	var beltIDs []uint
	if err := tx.Model(&models.HeartRate{}).
		Where("train_id = ?", record.TrainId).
		Select("DISTINCT belt_id").
		Pluck("belt_id", &beltIDs).Error; err != nil {
		return err
	}

	// The time windows and sample minutes are the same for every belt.
	ranges := getTimeRanges(record.StartTime)
	minutes := []float64{2, 4, 6}

	for _, bid := range beltIDs {
		averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
		if err != nil {
			return err
		}

		a, err := quadraticFit(minutes, []float64{averages["2min"], averages["4min"], averages["6min"]})
		if err != nil {
			return err
		}

		analysis := models.BeltAnalysis{
			TrainID:     record.TrainId,
			RunType:     record.RunType,
			BeltID:      bid,
			Avg2min:     averages["2min"],
			Avg4min:     averages["4min"],
			Avg6min:     averages["6min"],
			CurveParamA: a,
		}
		if err := tx.Create(&analysis).Error; err != nil {
			return err
		}
	}
	return nil
}
// collectCurveParams returns the curve_param_a column of every BeltAnalysis
// row reachable through tx, or the query error.
func collectCurveParams(tx *gorm.DB) ([]float64, error) {
	var params []float64
	result := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &params)
	if result.Error != nil {
		return nil, result.Error
	}
	return params, nil
}
// calculateStats returns the mean of data and its standard deviation, taken as
// the square root of gonum's stat.Variance (the unbiased sample variance).
func calculateStats(data []float64) (mean, stddev float64) {
	return stat.Mean(data, nil), math.Sqrt(stat.Variance(data, nil))
}
// generateNormalCurve samples the normal density with the given mean and
// standard deviation at numPoints evenly spaced x values covering
// [mean-3·stddev, mean+3·stddev], both ends included, and returns the x and y
// coordinates.
//
// Degenerate inputs yield no points rather than NaN values (which
// encoding/json refuses to marshal):
//   - numPoints <= 0 returns no points;
//   - a stddev that is not a positive, finite number returns no points;
//   - numPoints == 1 returns a single point at the mean, instead of the NaN x
//     that the 0·Inf step would produce.
func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) {
	if numPoints <= 0 || !(stddev > 0) || math.IsInf(stddev, 0) {
		return nil, nil
	}
	normal := distuv.Normal{Mu: mean, Sigma: stddev}
	if numPoints == 1 {
		return []float64{mean}, []float64{normal.Prob(mean)}
	}

	lo, hi := mean-3*stddev, mean+3*stddev
	step := (hi - lo) / float64(numPoints-1)
	x = make([]float64, numPoints)
	y = make([]float64, numPoints)
	for i := range x {
		x[i] = lo + float64(i)*step
		y[i] = normal.Prob(x[i])
	}
	return x, y
}
// calculateAverages returns, for each named window in ranges, the average
// heart_rates.value of the given train and belt over samples whose time lies
// in [tr.Start, tr.End). A window without samples averages to 0 via COALESCE.
//
// Windows are half-open so that a sample lying exactly on a boundary shared by
// two adjacent windows (as produced by getTimeRanges) is counted in only one of
// them. An inclusive BETWEEN would count it in both.
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
	averages := make(map[string]float64, len(ranges))
	for key, tr := range ranges {
		var avg float64
		// One parameterized aggregate query per window.
		err := tx.Raw(`
SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复
FROM heart_rates
WHERE train_id = ?
AND belt_id = ?
AND time >= ? AND time < ?`,
			trainID, beltID, tr.Start, tr.End,
		).Scan(&avg).Error
		if err != nil {
			return nil, err
		}
		averages[key] = avg
	}
	return averages, nil
}
// quadraticFit returns the leading coefficient a of the unique parabola
// y = a·x² + b·x + c passing through the three points (x[i], y[i]).
//
// It uses second divided differences, which are valid for any three distinct
// x values. For the evenly spaced x = {2, 4, 6} used by heartRateAnalyze this
// equals (y[2] - 2·y[1] + y[0]) / 8.
//
// An error is returned unless exactly three points are given, or when two x
// values coincide (no unique parabola exists).
func quadraticFit(x []float64, y []float64) (float64, error) {
	if len(x) != 3 || len(y) != 3 {
		return 0, errors.New("需要三个点")
	}
	if x[0] == x[1] || x[1] == x[2] || x[0] == x[2] {
		return 0, errors.New("quadraticFit: x values must be distinct")
	}
	slope01 := (y[1] - y[0]) / (x[1] - x[0])
	slope12 := (y[2] - y[1]) / (x[2] - x[1])
	return (slope12 - slope01) / (x[2] - x[0]), nil
}
// TimeRange is a window between two millisecond timestamps.
type TimeRange struct {
	Start int64 // window start, millisecond timestamp
	End   int64 // window end, millisecond timestamp
}

// getTimeRanges returns the analysis windows relative to startTime (ms):
// "2min" spans 120–240 s after the start, "4min" 240–360 s and "6min" 360–480 s.
func getTimeRanges(startTime int64) map[string]TimeRange {
	const windowMs = 2 * 60 * 1000 // every window is two minutes long
	windows := make(map[string]TimeRange, 3)
	for i, key := range []string{"2min", "4min", "6min"} {
		begin := startTime + int64(i+1)*windowMs
		windows[key] = TimeRange{Start: begin, End: begin + windowMs}
	}
	return windows
}