331 lines
8.2 KiB
Go
331 lines
8.2 KiB
Go
package controllers
|
||
|
||
import (
|
||
"gonum.org/v1/gonum/floats"
|
||
"gonum.org/v1/gonum/stat"
|
||
"gonum.org/v1/gonum/stat/distuv"
|
||
)
|
||
|
||
import (
|
||
"errors"
|
||
"github.com/gin-gonic/gin"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
"hr_receiver/config"
|
||
"hr_receiver/models"
|
||
"math"
|
||
"net/http"
|
||
)
|
||
|
||
// analyzeRunTypes lists the run types for which CreateTrainingRecord also runs
// heartRateAnalyze. Replace with your concrete values.
var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"}

// contains reports whether s exactly equals one of the entries in analyzeRunTypes.
func contains(s string) bool {
	found := false
	for i := 0; i < len(analyzeRunTypes) && !found; i++ {
		found = analyzeRunTypes[i] == s
	}
	return found
}
|
||
|
||
type TrainingController struct {
|
||
DB *gorm.DB
|
||
}
|
||
|
||
func NewTrainingController() *TrainingController {
|
||
return &TrainingController{DB: config.DB}
|
||
}
|
||
|
||
// 接收训练记录
|
||
func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
|
||
var record models.TrainRecord
|
||
|
||
// 绑定并验证JSON数据
|
||
if err := c.ShouldBindJSON(&record); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 使用事务保存数据[4](@ref)
|
||
err := tc.DB.Transaction(func(tx *gorm.DB) error {
|
||
// 保存主记录
|
||
if err := tx.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "train_id"}}, // 指定冲突的列
|
||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||
"max_heart_rate": record.MaxHeartRate,
|
||
"start_time": record.StartTime,
|
||
"end_time": record.EndTime,
|
||
"duration": record.Duration,
|
||
"people_num": record.PeopleNum,
|
||
"name": record.Name,
|
||
"evaluation": record.Evaluation,
|
||
}),
|
||
}).Omit("HeartRates", "belts").Create(&record).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 保存关联的心率数据
|
||
for i := range record.HeartRates {
|
||
if err := tx.Clauses(
|
||
clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||
DoUpdates: clause.Assignments(map[string]interface{}{"value": record.HeartRates[i].Value, "time": record.HeartRates[i].Time}),
|
||
},
|
||
).Create(&record.HeartRates[i]).Error; err != nil {
|
||
|
||
return err
|
||
}
|
||
}
|
||
if contains(record.RunType) {
|
||
err := tc.heartRateAnalyze(tx, record)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusCreated, gin.H{
|
||
"message": "数据保存成功",
|
||
"id": record.TrainId,
|
||
})
|
||
}
|
||
|
||
// Response types returned by HandleCurveAnalysis.
type (
	// AnalysisResponse is the JSON envelope of the curve-analysis endpoint.
	AnalysisResponse struct {
		Status  string `json:"status"`  // outcome string, e.g. "success"
		Message string `json:"message"` // human-readable detail
		Data    struct {
			Mean      float64      `json:"mean"`      // mean of the samples
			StdDev    float64      `json:"stdDev"`    // standard deviation of the samples
			Histogram []HistoBin   `json:"histogram"` // histogram bins
			Curve     []CurvePoint `json:"curve"`     // points of the fitted normal curve
		} `json:"data"`
	}

	// HistoBin is one histogram interval and the number of samples counted in it.
	HistoBin struct {
		BinStart float64 `json:"binStart"` // lower edge of the interval
		BinEnd   float64 `json:"binEnd"`   // upper edge of the interval
		Count    int     `json:"count"`    // samples counted in this interval
	}

	// CurvePoint is a single (x, y) point of a plotted curve.
	CurvePoint struct {
		X float64 `json:"x"` // x coordinate
		Y float64 `json:"y"` // y coordinate
	}
)
|
||
|
||
// analysis_handler.go
|
||
func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) {
|
||
// 获取数据库连接(根据实际项目配置调整)
|
||
|
||
// 1. 获取历史数据
|
||
aValues, err := collectCurveParams(tc.DB)
|
||
if err != nil {
|
||
c.JSON(500, gin.H{
|
||
"status": "error",
|
||
"message": "数据查询失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 2. 检查数据有效性
|
||
if len(aValues) < 10 { // 至少需要10个样本
|
||
c.JSON(400, gin.H{
|
||
"status": "fail",
|
||
"message": "数据量不足,至少需要10个样本",
|
||
})
|
||
return
|
||
}
|
||
|
||
// 3. 计算统计量
|
||
mean, stddev := calculateStats(aValues)
|
||
|
||
// 4. 生成直方图数据
|
||
histogram := calculateHistogram(aValues, 20) // 20个分箱
|
||
|
||
// 5. 生成正态曲线
|
||
x, y := generateNormalCurve(mean, stddev, 100)
|
||
|
||
// 6. 构造响应
|
||
response := AnalysisResponse{
|
||
Status: "success",
|
||
Message: "分析完成",
|
||
Data: struct {
|
||
Mean float64 `json:"mean"`
|
||
StdDev float64 `json:"stdDev"`
|
||
Histogram []HistoBin `json:"histogram"`
|
||
Curve []CurvePoint `json:"curve"`
|
||
}{
|
||
Mean: mean,
|
||
StdDev: stddev,
|
||
Histogram: histogram,
|
||
Curve: convertToCurvePoints(x, y),
|
||
},
|
||
}
|
||
|
||
c.JSON(200, response)
|
||
}
|
||
|
||
// 直方图计算函数
|
||
func calculateHistogram(data []float64, bins int) []HistoBin {
|
||
minV, maxV := floats.Min(data), floats.Max(data)
|
||
binWidth := (maxV - minV) / float64(bins)
|
||
|
||
counts := make([]int, bins)
|
||
for _, v := range data {
|
||
idx := int((v - minV) / binWidth)
|
||
if idx == bins { // 处理最大值刚好等于maxV的情况
|
||
idx--
|
||
}
|
||
counts[idx]++
|
||
}
|
||
|
||
histogram := make([]HistoBin, bins)
|
||
for i := 0; i < bins; i++ {
|
||
start := minV + float64(i)*binWidth
|
||
end := minV + float64(i+1)*binWidth
|
||
histogram[i] = HistoBin{
|
||
BinStart: start,
|
||
BinEnd: end,
|
||
Count: counts[i],
|
||
}
|
||
}
|
||
return histogram
|
||
}
|
||
|
||
// 转换曲线数据格式
|
||
func convertToCurvePoints(x, y []float64) []CurvePoint {
|
||
points := make([]CurvePoint, len(x))
|
||
for i := range x {
|
||
points[i] = CurvePoint{
|
||
X: x[i],
|
||
Y: y[i],
|
||
}
|
||
}
|
||
return points
|
||
}
|
||
|
||
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
|
||
startTime := record.StartTime
|
||
|
||
// 获取所有唯一的beltID
|
||
var beltIDs []uint
|
||
tx.Model(&models.HeartRate{}).Where("train_id = ?", record.TrainId).
|
||
Select("DISTINCT belt_id").Pluck("belt_id", &beltIDs)
|
||
|
||
// 对每个belt计算
|
||
for _, bid := range beltIDs {
|
||
// 计算平均心率
|
||
ranges := getTimeRanges(startTime)
|
||
averages, err := calculateAverages(tx, record.TrainId, bid, ranges)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 曲线拟合
|
||
x := []float64{2, 4, 6}
|
||
y := []float64{averages["2min"], averages["4min"], averages["6min"]}
|
||
a, _ := quadraticFit(x, y)
|
||
|
||
// 存储结果
|
||
analysis := models.BeltAnalysis{
|
||
TrainID: record.TrainId,
|
||
RunType: record.RunType,
|
||
BeltID: bid,
|
||
Avg2min: averages["2min"],
|
||
Avg4min: averages["4min"],
|
||
Avg6min: averages["6min"],
|
||
CurveParamA: a,
|
||
}
|
||
if err := tx.Create(&analysis).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func collectCurveParams(tx *gorm.DB) ([]float64, error) {
|
||
var aValues []float64
|
||
// 查询所有记录的 CurveParamA 字段
|
||
err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &aValues).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return aValues, nil
|
||
}
|
||
// calculateStats returns the arithmetic mean of data and the square root of
// gonum's stat.Variance (unweighted) as the standard deviation.
func calculateStats(data []float64) (mean, stddev float64) {
	variance := stat.Variance(data, nil)
	return stat.Mean(data, nil), math.Sqrt(variance)
}
|
||
|
||
// generateNormalCurve samples the normal density N(mean, stddev²) at numPoints
// evenly spaced x values from mean-3σ to mean+3σ, both ends inclusive.
//
// It returns nil slices in two cases: numPoints <= 0, or a mean/stddev for
// which the density is undefined (stddev not positive, or either value
// non-finite). Otherwise Normal.Prob yields NaN/Inf, which encoding/json
// rejects. stddev == 0 occurs whenever every sample is equal.
// With numPoints == 1 the single point is at the mean; previously the step
// divided by zero.
func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) {
	if numPoints <= 0 || !(stddev > 0) || math.IsInf(stddev, 0) ||
		math.IsNaN(mean) || math.IsInf(mean, 0) {
		return nil, nil
	}

	normal := distuv.Normal{
		Mu:    mean,
		Sigma: stddev,
	}

	if numPoints == 1 {
		return []float64{mean}, []float64{normal.Prob(mean)}
	}

	minV := mean - 3*stddev
	maxV := mean + 3*stddev
	step := (maxV - minV) / float64(numPoints-1)

	x = make([]float64, 0, numPoints)
	y = make([]float64, 0, numPoints)
	for i := 0; i < numPoints; i++ {
		xi := minV + float64(i)*step
		x = append(x, xi)
		y = append(y, normal.Prob(xi))
	}
	return x, y
}
|
||
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
|
||
averages := make(map[string]float64)
|
||
for key, tr := range ranges {
|
||
var avg float64
|
||
// 使用GORM Raw SQL提高效率[6,10](@ref)
|
||
err := tx.Raw(`
|
||
SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复
|
||
FROM heart_rates
|
||
WHERE train_id = ?
|
||
AND belt_id = ?
|
||
AND time BETWEEN ? AND ?`,
|
||
trainID, beltID, tr.Start, tr.End,
|
||
).Scan(&avg).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
averages[key] = avg
|
||
}
|
||
return averages, nil
|
||
}
|
||
|
||
// quadraticFit returns the leading coefficient a of the unique parabola
// y = a·x² + b·x + c that passes through the three points (x[i], y[i]).
//
// It computes a as the second divided difference f[x0, x1, x2]. This is exact
// for any three distinct x values. For equally spaced x (e.g. 2, 4, 6 minutes)
// it reduces to the old (y2 - 2y1 + y0) / (2h²) formula, which was only
// correct in that special case.
//
// It returns an error unless exactly three points are given, or if two x
// values coincide (the parabola would not be determined).
func quadraticFit(x []float64, y []float64) (float64, error) {
	if len(x) != 3 || len(y) != 3 {
		return 0, errors.New("需要三个点")
	}
	if x[0] == x[1] || x[1] == x[2] || x[0] == x[2] {
		return 0, errors.New("x values must be distinct")
	}

	slope01 := (y[1] - y[0]) / (x[1] - x[0])
	slope12 := (y[2] - y[1]) / (x[2] - x[1])
	a := (slope12 - slope01) / (x[2] - x[0])
	return a, nil
}
|
||
|
||
// TimeRange is a window given by two millisecond timestamps.
type TimeRange struct {
	Start int64 // window start, millisecond timestamp
	End   int64 // window end, millisecond timestamp
}

// getTimeRanges returns three consecutive 120-second windows measured from
// startTime (milliseconds):
//   - "2min" covers 120–240 s;
//   - "4min" covers 240–360 s;
//   - "6min" covers 360–480 s.
func getTimeRanges(startTime int64) map[string]TimeRange {
	const windowMs int64 = 120000

	window := func(offsetMs int64) TimeRange {
		return TimeRange{Start: startTime + offsetMs, End: startTime + offsetMs + windowMs}
	}

	ranges := make(map[string]TimeRange, 3)
	ranges["2min"] = window(120000)
	ranges["4min"] = window(240000)
	ranges["6min"] = window(360000)
	return ranges
}
|