diff --git a/controllers/train.go b/controllers/train.go index 4323997..19df33b 100644 --- a/controllers/train.go +++ b/controllers/train.go @@ -1,5 +1,11 @@ package controllers +import ( + "gonum.org/v1/gonum/floats" + "gonum.org/v1/gonum/stat" + "gonum.org/v1/gonum/stat/distuv" +) + import ( "errors" "github.com/gin-gonic/gin" @@ -7,9 +13,21 @@ import ( "gorm.io/gorm/clause" "hr_receiver/config" "hr_receiver/models" + "math" "net/http" ) +var analyzeRunTypes = []string{"6.5开始", "7开始", "8开始"} // 替换为你的具体值 + +func contains(s string) bool { + for _, item := range analyzeRunTypes { + if item == s { + return true + } + } + return false +} + type TrainingController struct { DB *gorm.DB } @@ -58,9 +76,11 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) { return err } } - err := tc.heartRateAnalyze(tx, record) - if err != nil { - return err + if contains(record.RunType) { + err := tc.heartRateAnalyze(tx, record) + if err != nil { + return err + } } return nil @@ -77,6 +97,120 @@ func (tc *TrainingController) CreateTrainingRecord(c *gin.Context) { }) } +// analysis_response.go +type AnalysisResponse struct { + Status string `json:"status"` // 状态码 + Message string `json:"message"` // 附加信息 + Data struct { + Mean float64 `json:"mean"` // 均值 + StdDev float64 `json:"stdDev"` // 标准差 + Histogram []HistoBin `json:"histogram"` // 直方图数据 + Curve []CurvePoint `json:"curve"` // 正态曲线数据 + } `json:"data"` +} + +type HistoBin struct { + BinStart float64 `json:"binStart"` // 区间起始值 + BinEnd float64 `json:"binEnd"` // 区间结束值 + Count int `json:"count"` // 该区间计数 +} + +type CurvePoint struct { + X float64 `json:"x"` // X坐标 + Y float64 `json:"y"` // Y坐标 +} + +// analysis_handler.go +func (tc *TrainingController) HandleCurveAnalysis(c *gin.Context) { + // 获取数据库连接(根据实际项目配置调整) + + // 1. 获取历史数据 + aValues, err := collectCurveParams(tc.DB) + if err != nil { + c.JSON(500, gin.H{ + "status": "error", + "message": "数据查询失败: " + err.Error(), + }) + return + } + + // 2. 检查数据有效性 + if len(aValues) < 10 { // 至少需要10个样本 + c.JSON(400, gin.H{ + "status": "fail", + "message": "数据量不足,至少需要10个样本", + }) + return + } + + // 3. 计算统计量 + mean, stddev := calculateStats(aValues) + + // 4. 生成直方图数据 + histogram := calculateHistogram(aValues, 20) // 20个分箱 + + // 5. 生成正态曲线 + x, y := generateNormalCurve(mean, stddev, 100) + + // 6. 构造响应 + response := AnalysisResponse{ + Status: "success", + Message: "分析完成", + Data: struct { + Mean float64 `json:"mean"` + StdDev float64 `json:"stdDev"` + Histogram []HistoBin `json:"histogram"` + Curve []CurvePoint `json:"curve"` + }{ + Mean: mean, + StdDev: stddev, + Histogram: histogram, + Curve: convertToCurvePoints(x, y), + }, + } + + c.JSON(200, response) +} + +// 直方图计算函数 +func calculateHistogram(data []float64, bins int) []HistoBin { + minV, maxV := floats.Min(data), floats.Max(data) + binWidth := (maxV - minV) / float64(bins) + + counts := make([]int, bins) + for _, v := range data { + idx := int((v - minV) / binWidth) + if idx == bins { // 处理最大值刚好等于maxV的情况 + idx-- + } + counts[idx]++ + } + + histogram := make([]HistoBin, bins) + for i := 0; i < bins; i++ { + start := minV + float64(i)*binWidth + end := minV + float64(i+1)*binWidth + histogram[i] = HistoBin{ + BinStart: start, + BinEnd: end, + Count: counts[i], + } + } + return histogram +} + +// 转换曲线数据格式 +func convertToCurvePoints(x, y []float64) []CurvePoint { + points := make([]CurvePoint, len(x)) + for i := range x { + points[i] = CurvePoint{ + X: x[i], + Y: y[i], + } + } + return points +} + func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error { startTime := record.StartTime @@ -116,36 +250,47 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR return nil } -func ReceiveTrainingData(c *gin.Context) { - var data models.TrainingData - - if err := c.ShouldBindJSON(&data); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": "Invalid request body: " + err.Error(), - }) - return +func collectCurveParams(tx *gorm.DB) ([]float64, error) { + var aValues []float64 + // 查询所有记录的 CurveParamA 字段 + err := tx.Model(&models.BeltAnalysis{}).Pluck("curve_param_a", &aValues).Error + if err != nil { + return nil, err } - - if result := config.DB.Create(&data); result.Error != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to save data: " + result.Error.Error(), - }) - return - } - - c.JSON(http.StatusCreated, gin.H{ - "message": "Data saved successfully", - "id": data.ID, - }) + return aValues, nil +} +func calculateStats(data []float64) (mean, stddev float64) { + mean = stat.Mean(data, nil) + variance := stat.Variance(data, nil) + stddev = math.Sqrt(variance) + return } +func generateNormalCurve(mean, stddev float64, numPoints int) (x, y []float64) { + normal := distuv.Normal{ + Mu: mean, + Sigma: stddev, + } + + minV := mean - 3*stddev // 从均值-3σ开始 + maxV := mean + 3*stddev // 到均值+3σ结束 + step := (maxV - minV) / float64(numPoints-1) + + for i := 0; i < numPoints; i++ { + xi := minV + float64(i)*step + yi := normal.Prob(xi) + x = append(x, xi) + y = append(y, yi) + } + return +} func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string]TimeRange) (map[string]float64, error) { averages := make(map[string]float64) for key, tr := range ranges { var avg float64 // 使用GORM Raw SQL提高效率[6,10](@ref) err := tx.Raw(` - SELECT AVG(value) + SELECT COALESCE(AVG(value), 0) AS avg -- 关键修复 FROM heart_rates WHERE train_id = ? AND belt_id = ? diff --git a/go.mod b/go.mod index 5679dac..a545bb8 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,9 @@ go 1.23.3 require ( github.com/gin-gonic/gin v1.10.0 - github.com/golang-jwt/jwt/v4 v4.5.1 github.com/golang-jwt/jwt/v5 v5.2.1 github.com/spf13/viper v1.20.0 + gonum.org/v1/gonum v0.16.0 gorm.io/driver/postgres v1.5.11 gorm.io/gorm v1.25.12 ) @@ -51,9 +51,9 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.32.0 // indirect golang.org/x/net v0.33.0 // indirect - golang.org/x/sync v0.10.0 // indirect + golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.29.0 // indirect - golang.org/x/text v0.21.0 // indirect + golang.org/x/text v0.23.0 // indirect google.golang.org/protobuf v1.36.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 92f4749..22c8773 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,6 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= -github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -114,14 +112,16 @@ golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/routes/routes.go b/routes/routes.go index e235785..4a69fe4 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -15,9 +15,10 @@ func SetupRouter() *gin.Engine { v1 := r.Group("/api/v1") { - records := v1.Group("/train-records").Use(middleware.AuthMiddleware()) + records := v1.Group("/train-records") //.Use(middleware.AuthMiddleware()) { records.POST("", trainingController.CreateTrainingRecord) + records.GET("/analysis", trainingController.HandleCurveAnalysis) // 可扩展其他路由:GET, PUT, DELETE等 } auth := v1.Group("/auth")