refactor: data analyze result.

This commit is contained in:
2025-04-01 15:35:24 +08:00
parent 7b9e870bf9
commit 6e8232ff7f
4 changed files with 180 additions and 34 deletions

View File

@ -1,5 +1,11 @@
package controllers
import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/gonum/stat/distuv"
)
import (
"errors"
"github.com/gin-gonic/gin"
@ -7,9 +13,21 @@ import (
"gorm.io/gorm/clause"
"hr_receiver/config"
"hr_receiver/models"
"math"
"net/http"
)
var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"} // 替换为你的具体值
func contains(s string) bool {
for _, item := range analyzeRunTypes {
if item == s {
return true
}
}
return false
}
type TrainingController struct {
DB *gorm.DB
}
@ -58,9 +76,11 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
return err
}
}
err := tc.heartRateAnalyze(tx, record)
if err != nil {
return err
if contains(record.RunType) {
err := tc.heartRateAnalyze(tx, record)
if err != nil {
return err
}
}
return nil
@ -77,6 +97,120 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) {
})
}
// analysis_response.go
type AnalysisResponse struct {
Status string `json:"status"` // 状态码
Message string `json:"message"` // 附加信息
Data struct {
Mean float64 `json:"mean"` // 均值
StdDev float64 `json:"stdDev"` // 标准差
Histogram []HistoBin `json:"histogram"` // 直方图数据
Curve []CurvePoint `json:"curve"` // 正态曲线数据
} `json:"data"`
}
type HistoBin struct {
BinStart float64 `json:"binStart"` // 区间起始值
BinEnd float64 `json:"binEnd"` // 区间结束值
Count int `json:"count"` // 该区间计数
}
type CurvePoint struct {
X float64 `json:"x"` // X坐标
Y float64 `json:"y"` // Y坐标
}
// analysis_handler.go
func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) {
// 获取数据库连接(根据实际项目配置调整)
// 1. 获取历史数据
aValues, err := collectCurveParams(tc.DB)
if err != nil {
c.JSON(500, gin.H{
"status": "error",
"message": "数据查询失败: " + err.Error(),
})
return
}
// 2. 检查数据有效性
if len(aValues) < 10 { // 至少需要10个样本
c.JSON(400, gin.H{
"status": "fail",
"message": "数据量不足至少需要10个样本",
})
return
}
// 3. 计算统计量
mean, stddev := calculateStats(aValues)
// 4. 生成直方图数据
histogram := calculateHistogram(aValues, 20) // 20个分箱
// 5. 生成正态曲线
x, y := generateNormalCurve(mean, stddev, 100)
// 6. 构造响应
response := AnalysisResponse{
Status: "success",
Message: "分析完成",
Data: struct {
Mean float64 `json:"mean"`
StdDev float64 `json:"stdDev"`
Histogram []HistoBin `json:"histogram"`
Curve []CurvePoint `json:"curve"`
}{
Mean: mean,
StdDev: stddev,
Histogram: histogram,
Curve: convertToCurvePoints(x, y),
},
}
c.JSON(200, response)
}
// 直方图计算函数
func calculateHistogram(data []float64, bins int) []HistoBin {
minV, maxV := floats.Min(data), floats.Max(data)
binWidth := (maxV - minV) / float64(bins)
counts := make([]int, bins)
for _, v := range data {
idx := int((v - minV) / binWidth)
if idx == bins { // 处理最大值刚好等于maxV的情况
idx--
}
counts[idx]++
}
histogram := make([]HistoBin, bins)
for i := 0; i < bins; i++ {
start := minV + float64(i)*binWidth
end := minV + float64(i+1)*binWidth
histogram[i] = HistoBin{
BinStart: start,
BinEnd: end,
Count: counts[i],
}
}
return histogram
}
// 转换曲线数据格式
func convertToCurvePoints(x, y []float64) []CurvePoint {
points := make([]CurvePoint, len(x))
for i := range x {
points[i] = CurvePoint{
X: x[i],
Y: y[i],
}
}
return points
}
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
startTime := record.StartTime
@ -116,36 +250,47 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR
return nil
}
func ReceiveTrainingData(c *gin.Context) {
var data models.TrainingData
if err := c.ShouldBindJSON(&data); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid request body: " + err.Error(),
})
return
func collectCurveParams(tx *gorm.DB) ([]float64, error) {
var aValues []float64
// 查询所有记录的 CurveParamA 字段
err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &aValues).Error
if err != nil {
return nil, err
}
if result := config.DB.Create(&data); result.Error != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to save data: " + result.Error.Error(),
})
return
}
c.JSON(http.StatusCreated, gin.H{
"message": "Data saved successfully",
"id": data.ID,
})
return aValues, nil
}
func calculateStats(data []float64) (mean, stddev float64) {
mean = stat.Mean(data, nil)
variance := stat.Variance(data, nil)
stddev = math.Sqrt(variance)
return
}
func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) {
normal := distuv.Normal{
Mu: mean,
Sigma: stddev,
}
minV := mean - 3*stddev // 从均值-3σ开始
maxV := mean + 3*stddev // 到均值+3σ结束
step := (maxV - minV) / float64(numPoints-1)
for i := 0; i < numPoints; i++ {
xi := minV + float64(i)*step
yi := normal.Prob(xi)
x = append(x, xi)
y = append(y, yi)
}
return
}
func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) {
averages := make(map[string]float64)
for key, tr := range ranges {
var avg float64
// 使用GORM Raw SQL提高效率[6,10](@ref)
err := tx.Raw(`
SELECT AVG(value)
SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复
FROM heart_rates
WHERE train_id = ?
AND belt_id = ?