830 lines
21 KiB
Go
830 lines
21 KiB
Go
package controllers
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/sajari/regression"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
"hr_receiver/config"
|
||
"hr_receiver/models"
|
||
"log"
|
||
"math"
|
||
"net/http"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
)
|
||
|
||
type StepTrainingController struct {
|
||
DB *gorm.DB
|
||
}
|
||
|
||
func NewStepTrainingController() *StepTrainingController {
|
||
return &StepTrainingController{DB: config.DB}
|
||
}
|
||
|
||
// 接收训练记录
|
||
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
|
||
var record models.StepTrainRecord
|
||
|
||
// 绑定并验证JSON数据
|
||
if err := c.ShouldBindJSON(&record); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 使用事务保存数据[4](@ref)
|
||
err := tc.DB.Transaction(func(tx *gorm.DB) error {
|
||
// 保存主记录
|
||
if err := tx.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "train_id"}}, // 指定冲突的列
|
||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||
"max_heart_rate": record.MaxHeartRate,
|
||
"start_time": record.StartTime,
|
||
"end_time": record.EndTime,
|
||
"duration": record.Duration,
|
||
"dead_zone": record.DeadZone,
|
||
"name": record.Name,
|
||
"evaluation": record.Evaluation,
|
||
}),
|
||
}).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 保存关联的心率数据
|
||
for i := range record.HeartRates {
|
||
if err := tx.Clauses(
|
||
clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||
DoUpdates: clause.Assignments(map[string]interface{}{"heart_rate_type": record.HeartRates[i].HeartRateType, "value": record.HeartRates[i].Value, "time": record.HeartRates[i].Time}),
|
||
},
|
||
).Create(&record.HeartRates[i]).Error; err != nil {
|
||
|
||
return err
|
||
}
|
||
}
|
||
for i := range record.StrideFreqs {
|
||
if err := tx.Clauses(
|
||
clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||
DoUpdates: clause.Assignments(map[string]interface{}{"value": record.StrideFreqs[i].Value, "time": record.StrideFreqs[i].Time}),
|
||
},
|
||
).Create(&record.StrideFreqs[i]).Error; err != nil {
|
||
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusCreated, gin.H{
|
||
"message": "数据保存成功",
|
||
"id": record.TrainId,
|
||
})
|
||
}
|
||
|
||
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
|
||
// 定义分页参数结构
|
||
type PaginationParams struct {
|
||
PageNum int `form:"pageNum,default=1"` // 页码,默认第一页
|
||
PageSize int `form:"pageSize,default=10"` // 每页数量,默认10条
|
||
}
|
||
|
||
var params PaginationParams
|
||
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// 验证分页参数有效性
|
||
if params.PageNum < 1 {
|
||
params.PageNum = 1
|
||
}
|
||
if params.PageSize < 1 || params.PageSize > 100 {
|
||
params.PageSize = 10
|
||
}
|
||
|
||
// 计算偏移量
|
||
offset := (params.PageNum - 1) * params.PageSize
|
||
|
||
var (
|
||
records []models.StepTrainRecord
|
||
totalRows int64
|
||
)
|
||
|
||
// 获取总记录数
|
||
if err := tc.DB.Model(&models.StepTrainRecord{}).Count(&totalRows).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
|
||
return
|
||
}
|
||
|
||
// 查询分页数据(按开始时间倒序排列)
|
||
result := tc.DB.
|
||
Order("start_time DESC"). // 按开始时间倒序
|
||
Offset(offset).
|
||
Limit(params.PageSize).
|
||
Find(&records)
|
||
|
||
if result.Error != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
||
return
|
||
}
|
||
|
||
// 计算总页数
|
||
totalPages := int(math.Ceil(float64(totalRows) / float64(params.PageSize)))
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "查询成功",
|
||
"data": gin.H{
|
||
"list": records,
|
||
"pagination": gin.H{
|
||
"currentPage": params.PageNum,
|
||
"pageSize": params.PageSize,
|
||
"totalPage": totalPages,
|
||
"totalList": totalRows,
|
||
},
|
||
},
|
||
})
|
||
}
|
||
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
|
||
// 从URL路径参数获取trainId
|
||
trainId := c.Param("trainId")
|
||
if trainId == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
|
||
return
|
||
}
|
||
|
||
// 将字符串trainId转换为uint类型
|
||
tid, err := strconv.ParseInt(trainId, 10, 64)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
|
||
return
|
||
}
|
||
|
||
var record models.StepTrainRecord
|
||
|
||
// 查询主记录并预加载关联的心率和步频数据
|
||
result := tc.DB.Where("train_id = ?", uint(tid)).
|
||
Preload("HeartRates").
|
||
Preload("StrideFreqs").
|
||
First(&record)
|
||
|
||
if result.Error != nil {
|
||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
|
||
} else {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
||
}
|
||
return
|
||
}
|
||
|
||
// 成功返回数据
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "查询成功",
|
||
"data": record,
|
||
})
|
||
}
|
||
|
||
// SpeedSegment is one contiguous stretch during which the stride-frequency
// value stayed constant.
type SpeedSegment struct {
	// Duration is the length of the stretch; the code in this file builds it
	// as a time difference divided by 1000 (presumably milliseconds to
	// seconds — confirm against the sample time unit).
	Duration float64
	// Speed is the stride-frequency value that was held during the stretch.
	Speed float64
}
|
||
|
||
// 实现线性回归算法
|
||
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
|
||
if len(averages) == 0 {
|
||
return models.RegressionResult{
|
||
Equation: "无数据",
|
||
}
|
||
}
|
||
|
||
// 收集数据点
|
||
var points []struct{ x, y float64 }
|
||
for _, m := range averages {
|
||
for x, y := range m {
|
||
points = append(points, struct{ x, y float64 }{x, y})
|
||
}
|
||
}
|
||
|
||
// 使用回归库计算
|
||
r := new(regression.Regression)
|
||
r.SetObserved("y")
|
||
r.SetVar(0, "x")
|
||
|
||
for _, p := range points {
|
||
r.Train(regression.DataPoint(p.y, []float64{p.x}))
|
||
}
|
||
|
||
if err := r.Run(); err != nil {
|
||
log.Printf("线性回归计算失败: %v", err)
|
||
return models.RegressionResult{
|
||
Equation: "计算失败",
|
||
}
|
||
}
|
||
|
||
// 创建结果
|
||
slope := r.Coeff(0)
|
||
intercept := r.Coeff(1)
|
||
r2 := r.R2
|
||
return models.RegressionResult{
|
||
RegressionType: models.LinearRegression,
|
||
Slope: &slope,
|
||
Intercept: &intercept,
|
||
RSquared: &r2,
|
||
Equation: r.Formula,
|
||
}
|
||
}
|
||
|
||
// 实现对数和二次回归算法
|
||
// 对数回归算法
|
||
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
|
||
if len(averages) == 0 {
|
||
return models.RegressionResult{
|
||
Equation: "无数据",
|
||
}
|
||
}
|
||
|
||
// 收集数据点
|
||
r := new(regression.Regression)
|
||
r.SetObserved("y")
|
||
r.SetVar(0, "log(x+1)")
|
||
|
||
for _, m := range averages {
|
||
for speed, hr := range m {
|
||
logSpeed := math.Log(speed + 1)
|
||
r.Train(regression.DataPoint(hr, []float64{logSpeed}))
|
||
}
|
||
}
|
||
|
||
if err := r.Run(); err != nil {
|
||
log.Printf("对数回归计算失败: %v", err)
|
||
return models.RegressionResult{
|
||
Equation: "计算失败",
|
||
}
|
||
}
|
||
|
||
// 创建结果
|
||
logA := r.Coeff(1)
|
||
logB := r.Coeff(0)
|
||
r2 := r.R2
|
||
return models.RegressionResult{
|
||
RegressionType: models.LogarithmicRegression,
|
||
LogA: &logA,
|
||
LogB: &logB,
|
||
RSquared: &r2,
|
||
Equation: r.Formula,
|
||
}
|
||
}
|
||
|
||
// Quadratic regression: earlier library-based implementation, kept commented out for reference.
|
||
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||
// if len(averages) == 0 {
|
||
// return models.RegressionResult{
|
||
// Equation: "无数据",
|
||
// }
|
||
// }
|
||
//
|
||
// // 收集数据点
|
||
// r := new(regression.Regression)
|
||
// r.SetObserved("y")
|
||
// r.SetVar(0, "x")
|
||
// r.SetVar(1, "x²")
|
||
//
|
||
// for _, m := range averages {
|
||
// for speed, hr := range m {
|
||
// speedSq := math.Pow(speed, 2)
|
||
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
|
||
// }
|
||
// }
|
||
//
|
||
// if err := r.Run(); err != nil {
|
||
// log.Printf("二次回归计算失败: %v", err)
|
||
// return models.RegressionResult{
|
||
// Equation: "计算失败",
|
||
// }
|
||
// }
|
||
//
|
||
// // 创建结果
|
||
// a := r.Coeff(2)
|
||
// b := r.Coeff(1)
|
||
// c := r.Coeff(0)
|
||
// r2 := r.R2
|
||
// return models.RegressionResult{
|
||
// RegressionType: models.QuadraticRegression,
|
||
// QuadraticA: &a,
|
||
// QuadraticB: &b,
|
||
// QuadraticC: &c,
|
||
// RSquared: &r2,
|
||
// Equation: r.Formula,
|
||
// }
|
||
//}
|
||
|
||
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||
if len(averages) == 0 {
|
||
return models.RegressionResult{
|
||
Equation: "无数据",
|
||
}
|
||
}
|
||
|
||
// 步骤1:收集所有数据点(与Flutter一致)
|
||
var xValues []float64
|
||
var yValues []float64
|
||
for _, m := range averages {
|
||
for speed, hr := range m {
|
||
xValues = append(xValues, speed)
|
||
yValues = append(yValues, hr)
|
||
}
|
||
}
|
||
n := float64(len(xValues))
|
||
|
||
// 步骤2:计算各项和(完全匹配Flutter的计算)
|
||
var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
|
||
for i := 0; i < len(xValues); i++ {
|
||
x := xValues[i]
|
||
y := yValues[i]
|
||
x2 := x * x
|
||
x3 := x2 * x
|
||
x4 := x3 * x
|
||
|
||
sumX += x
|
||
sumY += y
|
||
sumX2 += x2
|
||
sumX3 += x3
|
||
sumX4 += x4
|
||
sumXY += x * y
|
||
sumX2Y += x2 * y
|
||
}
|
||
|
||
// 步骤3:构建正规方程矩阵(与Flutter完全一致)
|
||
matrix := [3][3]float64{
|
||
{n, sumX, sumX2},
|
||
{sumX, sumX2, sumX3},
|
||
{sumX2, sumX3, sumX4},
|
||
}
|
||
vector := []float64{sumY, sumXY, sumX2Y}
|
||
|
||
// 步骤4:计算矩阵行列式(复制Flutter的determinant3x3逻辑)
|
||
det := matrix[0][0]*(matrix[1][1]*matrix[2][2]-matrix[1][2]*matrix[2][1]) -
|
||
matrix[0][1]*(matrix[1][0]*matrix[2][2]-matrix[1][2]*matrix[2][0]) +
|
||
matrix[0][2]*(matrix[1][0]*matrix[2][1]-matrix[1][1]*matrix[2][0])
|
||
|
||
if det == 0 {
|
||
return models.RegressionResult{
|
||
Equation: "无法拟合",
|
||
}
|
||
}
|
||
|
||
// 步骤5:克莱姆法则求解系数(顺序与Flutter一致)
|
||
// 注意:最终系数顺序 a=二次项, b=一次项, c=常数项
|
||
c := det3x3([3][3]float64{
|
||
{vector[0], matrix[0][1], matrix[0][2]},
|
||
{vector[1], matrix[1][1], matrix[1][2]},
|
||
{vector[2], matrix[2][1], matrix[2][2]},
|
||
}) / det
|
||
|
||
b := det3x3([3][3]float64{
|
||
{matrix[0][0], vector[0], matrix[0][2]},
|
||
{matrix[1][0], vector[1], matrix[1][2]},
|
||
{matrix[2][0], vector[2], matrix[2][2]},
|
||
}) / det
|
||
|
||
a := det3x3([3][3]float64{
|
||
{matrix[0][0], matrix[0][1], vector[0]},
|
||
{matrix[1][0], matrix[1][1], vector[1]},
|
||
{matrix[2][0], matrix[2][1], vector[2]},
|
||
}) / det
|
||
|
||
// 步骤6:计算R平方(完全复制Flutter的计算逻辑)
|
||
var ssRes, ssTot float64
|
||
meanY := sumY / n
|
||
for i := 0; i < len(xValues); i++ {
|
||
x := xValues[i]
|
||
y := yValues[i]
|
||
yPred := a*x*x + b*x + c
|
||
ssRes += math.Pow(y-yPred, 2)
|
||
ssTot += math.Pow(y-meanY, 2)
|
||
}
|
||
rSquared := 0.0
|
||
if ssTot != 0 {
|
||
rSquared = 1 - ssRes/ssTot
|
||
}
|
||
|
||
// 步骤7:格式化公式字符串(与Flutter格式完全一致)
|
||
equation := formatEquation(a, b, c, rSquared)
|
||
|
||
return models.RegressionResult{
|
||
RegressionType: models.QuadraticRegression,
|
||
QuadraticA: &a,
|
||
QuadraticB: &b,
|
||
QuadraticC: &c,
|
||
RSquared: &rSquared,
|
||
Equation: equation,
|
||
}
|
||
}
|
||
|
||
// det3x3 returns the determinant of m, expanded along the first row.
func det3x3(m [3][3]float64) float64 {
	top, mid, bot := m[0], m[1], m[2]

	// 2x2 minors obtained by deleting row 0 and column 0, 1, 2 respectively.
	minor0 := mid[1]*bot[2] - mid[2]*bot[1]
	minor1 := mid[0]*bot[2] - mid[2]*bot[0]
	minor2 := mid[0]*bot[1] - mid[1]*bot[0]

	return top[0]*minor0 - top[1]*minor1 + top[2]*minor2
}
|
||
|
||
// formatEquation renders a quadratic fit as
// "y = A x² ± B x ± C (R² = R)", with every number printed to four decimals.
//
// A negative leading coefficient keeps its minus sign attached ("-0.5000 x²");
// for the linear and constant terms the sign becomes the " + " / " - "
// separator and the magnitude is printed without its minus sign.
func formatEquation(a, b, c, r2 float64) string {
	// term renders one coefficient. plus/minus are the prefixes used for
	// v >= 0 and for everything else; the minus sign of the formatted number
	// is stripped only in the latter case.
	term := func(v float64, plus, minus, suffix string) string {
		text := fmt.Sprintf("%.4f", v)
		if v >= 0 {
			return plus + text + suffix
		}
		return minus + strings.TrimPrefix(text, "-") + suffix
	}

	return "y = " +
		term(a, "", "-", " x²") +
		term(b, " + ", " - ", " x") +
		term(c, " + ", " - ", "") +
		" (R² = " + fmt.Sprintf("%.4f", r2) + ")"
}
|
||
|
||
// 步频数据转换为速度段
|
||
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
|
||
if len(steps) == 0 {
|
||
return []SpeedSegment{}
|
||
}
|
||
|
||
// 过滤零值并排序
|
||
validSteps := make([]models.StepStrideFreq, 0, len(steps))
|
||
for _, s := range steps {
|
||
if s.Value > 0 {
|
||
validSteps = append(validSteps, s)
|
||
}
|
||
}
|
||
|
||
if len(validSteps) == 0 {
|
||
return []SpeedSegment{}
|
||
}
|
||
|
||
// 按时间排序
|
||
for i := 0; i < len(validSteps)-1; i++ {
|
||
for j := i + 1; j < len(validSteps); j++ {
|
||
if validSteps[i].Time > validSteps[j].Time {
|
||
validSteps[i], validSteps[j] = validSteps[j], validSteps[i]
|
||
}
|
||
}
|
||
}
|
||
|
||
// 创建速度段
|
||
segments := make([]SpeedSegment, 0)
|
||
startTime := validSteps[0].Time
|
||
currentValue := validSteps[0].Value
|
||
|
||
for i := 1; i < len(validSteps); i++ {
|
||
if validSteps[i].Value != currentValue {
|
||
duration := float64(validSteps[i].Time-startTime) / 1000.0
|
||
if duration > 0 {
|
||
segments = append(segments, SpeedSegment{
|
||
Duration: duration,
|
||
Speed: float64(currentValue),
|
||
})
|
||
}
|
||
startTime = validSteps[i].Time
|
||
currentValue = validSteps[i].Value
|
||
}
|
||
}
|
||
|
||
// 添加最后一个段
|
||
if len(validSteps) > 0 {
|
||
duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
|
||
if duration > 0 {
|
||
segments = append(segments, SpeedSegment{
|
||
Duration: duration,
|
||
Speed: float64(currentValue),
|
||
})
|
||
}
|
||
}
|
||
|
||
return segments
|
||
}
|
||
|
||
// 计算区段平均值
|
||
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
|
||
currentTime := 0.0
|
||
results := make([]map[float64]float64, 0)
|
||
|
||
for _, seg := range segments {
|
||
minRequired := 60 + (60 - float64(errorThreshold))
|
||
|
||
// 跳过不满足条件的区段
|
||
if seg.Duration < minRequired {
|
||
currentTime += seg.Duration
|
||
continue
|
||
}
|
||
|
||
// 计算时间窗口
|
||
startSec := currentTime + 60
|
||
endSec := currentTime
|
||
if seg.Duration >= 120 {
|
||
endSec = currentTime + 120
|
||
} else {
|
||
endSec = currentTime + 120 - float64(errorThreshold)
|
||
}
|
||
|
||
// 收集该区段的心率数据
|
||
sum, count := 0, 0
|
||
for _, hr := range heartRates {
|
||
sec := float64(hr.Time) / 1000.0
|
||
if sec >= startSec && sec <= endSec {
|
||
sum += hr.Value
|
||
count++
|
||
}
|
||
}
|
||
|
||
// 计算平均值
|
||
if count > 0 {
|
||
avg := float64(sum) / float64(count)
|
||
results = append(results, map[float64]float64{seg.Speed: avg})
|
||
}
|
||
|
||
currentTime += seg.Duration
|
||
}
|
||
|
||
return results
|
||
}
|
||
|
||
// 计算步频区段的心率平均值
|
||
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
|
||
segments := convertStrideFrequencyToSegments(steps)
|
||
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
|
||
}
|
||
|
||
// 存储回归结果到数据库(支持多种回归类型)
|
||
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
|
||
return tc.DB.Transaction(func(tx *gorm.DB) error {
|
||
for i := range results {
|
||
results[i].TrainId = trainId
|
||
// 使用复合唯一约束确保每种回归类型只存储一条记录
|
||
err := tx.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{
|
||
{Name: "id"},
|
||
},
|
||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||
"equation": results[i].Equation,
|
||
"slope": results[i].Slope,
|
||
"intercept": results[i].Intercept,
|
||
"log_a": results[i].LogA,
|
||
"log_b": results[i].LogB,
|
||
"quadratic_a": results[i].QuadraticA,
|
||
"quadratic_b": results[i].QuadraticB,
|
||
"quadratic_c": results[i].QuadraticC,
|
||
"r_squared": results[i].RSquared,
|
||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||
}),
|
||
}).Create(&results[i]).Error
|
||
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// 获取或计算回归结果(返回多种回归类型列表)
|
||
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
|
||
// 尝试从数据库获取所有类型的回归结果
|
||
var results []models.RegressionResult
|
||
err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
|
||
|
||
// 如果已存在三种类型的结果,直接返回
|
||
if err == nil && len(results) >= 3 {
|
||
return results, nil
|
||
}
|
||
|
||
// 查询训练记录及相关数据
|
||
var record models.StepTrainRecord
|
||
if err := tc.DB.
|
||
Where("train_id = ?", uint(trainId)).
|
||
Preload("HeartRates", "heart_rate_type = ?", 1).
|
||
Preload("StrideFreqs", "predict_value = ?", 1).
|
||
First(&record).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 计算心率平均值
|
||
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
|
||
if len(averages) == 0 {
|
||
return nil, errors.New("无足够数据进行回归计算")
|
||
}
|
||
|
||
// 创建三种回归类型的结果
|
||
results = make([]models.RegressionResult, 3)
|
||
|
||
// 线性回归
|
||
linearRes := performLinearRegression(averages)
|
||
results[0] = models.RegressionResult{
|
||
RegressionType: models.LinearRegression,
|
||
TrainId: trainId,
|
||
Equation: linearRes.Equation,
|
||
Slope: linearRes.Slope,
|
||
Intercept: linearRes.Intercept,
|
||
RSquared: linearRes.RSquared,
|
||
}
|
||
|
||
// 对数回归
|
||
logRes := performLogarithmicRegression(averages)
|
||
results[1] = models.RegressionResult{
|
||
RegressionType: models.LogarithmicRegression,
|
||
TrainId: trainId,
|
||
Equation: logRes.Equation,
|
||
LogA: logRes.LogA,
|
||
LogB: logRes.LogB,
|
||
RSquared: logRes.RSquared,
|
||
}
|
||
|
||
// 二次回归
|
||
quadRes := performQuadraticRegression(averages)
|
||
results[2] = models.RegressionResult{
|
||
RegressionType: models.QuadraticRegression,
|
||
TrainId: trainId,
|
||
Equation: quadRes.Equation,
|
||
QuadraticA: quadRes.QuadraticA,
|
||
QuadraticB: quadRes.QuadraticB,
|
||
QuadraticC: quadRes.QuadraticC,
|
||
RSquared: quadRes.RSquared,
|
||
}
|
||
|
||
// 批量保存结果到数据库
|
||
if err := tc.SaveRegressionResults(trainId, results); err != nil {
|
||
log.Printf("保存回归结果失败: %v", err)
|
||
return nil, err
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// 新增接口:获取回归结果
|
||
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
|
||
trainIdStr := c.Param("trainId")
|
||
tid, err := strconv.ParseUint(trainIdStr, 10, 32)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||
return
|
||
}
|
||
|
||
result, err := tc.GetOrCalculateRegression(uint(tid))
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "获取成功",
|
||
"data": result,
|
||
})
|
||
}
|
||
|
||
// 获取训练记录的排名
|
||
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
||
// 解析参数
|
||
trainIdStr := c.Param("trainId")
|
||
|
||
regressionTypeStr := c.Query("type")
|
||
regressionType, err := strconv.Atoi(regressionTypeStr) // 字符串转整型
|
||
if err != nil {
|
||
// 转换失败时返回400错误
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
|
||
return
|
||
}
|
||
|
||
regType := models.RegressionType(regressionType)
|
||
|
||
// 验证回归类型
|
||
if regType != models.LinearRegression && regType != models.QuadraticRegression {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
|
||
return
|
||
}
|
||
|
||
// 转换trainId
|
||
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||
return
|
||
}
|
||
trainId := uint(tid)
|
||
|
||
// 确保回归结果存在
|
||
if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 获取所有记录用于排名
|
||
var records []models.RegressionResult
|
||
query := tc.DB.Model(&models.RegressionResult{})
|
||
switch regType {
|
||
case models.LinearRegression:
|
||
query = query.Where("slope IS NOT NULL").Select("train_id, slope,regression_type")
|
||
case models.QuadraticRegression:
|
||
query = query.Where("quadratic_a IS NOT NULL").
|
||
Select("train_id, ABS(quadratic_a),regression_type")
|
||
}
|
||
if err := query.Debug().Find(&records).Error; err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
|
||
return
|
||
}
|
||
|
||
// 处理无数据情况
|
||
if len(records) == 0 {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
|
||
return
|
||
}
|
||
|
||
sort.Slice(records, func(i, j int) bool {
|
||
// 处理空指针情况
|
||
slopeI := records[i].Slope
|
||
slopeJ := records[j].Slope
|
||
|
||
if slopeI == nil && slopeJ == nil {
|
||
return false // 两者均为空时视为相等
|
||
}
|
||
if slopeI == nil {
|
||
return false // 空值视为极大值(排在最后)
|
||
}
|
||
if slopeJ == nil {
|
||
return true // 非空值始终排在前
|
||
}
|
||
|
||
return *slopeI > *slopeJ
|
||
})
|
||
|
||
// 计算排名(处理并列)
|
||
currentRank := 1
|
||
rankMap := make(map[uint]int)
|
||
for i, record := range records {
|
||
// 处理第一个记录
|
||
if i == 0 {
|
||
rankMap[record.TrainId] = currentRank
|
||
continue
|
||
}
|
||
|
||
if regType == models.LinearRegression {
|
||
if records[i].Slope != records[i-1].Slope {
|
||
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
|
||
}
|
||
rankMap[record.TrainId] = currentRank
|
||
} else {
|
||
if records[i].QuadraticA != records[i-1].QuadraticA {
|
||
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
|
||
}
|
||
rankMap[record.TrainId] = currentRank
|
||
}
|
||
// 检测值是否变化
|
||
}
|
||
|
||
// 获取当前训练记录的排名
|
||
rank, exists := rankMap[trainId]
|
||
if !exists {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
|
||
return
|
||
}
|
||
|
||
// 返回响应
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"message": "排名查询成功",
|
||
"data": gin.H{
|
||
"trainId": trainId,
|
||
"type": regressionType,
|
||
"rank": rank,
|
||
"total": len(records),
|
||
},
|
||
})
|
||
}
|