Files
hr_data_analyzer/controllers/step_train.go
2025-08-05 10:55:19 +08:00

896 lines
23 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controllers
import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"math"
	"net/http"
	"sort"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/sajari/regression"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"hr_receiver/config"
	"hr_receiver/models"
)
// StepTrainingController groups the HTTP handlers and regression helpers for
// step-training records. Every query in this file goes through DB.
type StepTrainingController struct {
// DB is the GORM handle used for all reads and writes.
DB *gorm.DB
}
// NewStepTrainingController returns a controller whose DB field is set to the
// package-level config.DB handle.
func NewStepTrainingController() *StepTrainingController {
	controller := new(StepTrainingController)
	controller.DB = config.DB
	return controller
}
// CreateTrainingRecord receives a training record as JSON, upserts it together
// with its heart-rate and stride-frequency samples inside one transaction, and
// then starts an asynchronous regression calculation for the stored record.
//
// Responses:
//   - 400 when the body cannot be bound to models.StepTrainRecord.
//   - 500 when the transaction fails (no regression is started in that case).
//   - 201 with the record's train ID on success.
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
	var record models.StepTrainRecord
	// Bind and validate the JSON payload.
	if err := c.ShouldBindJSON(&record); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Persist the main record and its samples atomically.
	err := tc.DB.Transaction(func(tx *gorm.DB) error {
		// Upsert the main record, keyed on train_id; associations are written
		// separately below, so they are omitted here.
		if err := tx.Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "train_id"}},
			DoUpdates: clause.Assignments(map[string]interface{}{
				"max_heart_rate": record.MaxHeartRate,
				"start_time":     record.StartTime,
				"end_time":       record.EndTime,
				"duration":       record.Duration,
				"dead_zone":      record.DeadZone,
				"name":           record.Name,
				"evaluation":     record.Evaluation,
			}),
		}).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
			return err
		}

		// Upsert the heart-rate samples, keyed on identifier.
		for i := range record.HeartRates {
			hr := &record.HeartRates[i]
			if err := tx.Clauses(clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"heart_rate_type": hr.HeartRateType,
					"value":           hr.Value,
					"time":            hr.Time,
				}),
			}).Create(hr).Error; err != nil {
				return err
			}
		}

		// Upsert the stride-frequency samples, keyed on identifier.
		for i := range record.StrideFreqs {
			sf := &record.StrideFreqs[i]
			if err := tx.Clauses(clause.OnConflict{
				Columns: []clause.Column{{Name: "identifier"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"value": sf.Value,
					"time":  sf.Time,
				}),
			}).Create(sf).Error; err != nil {
				return err
			}
		}
		return nil
	})
	// Check the transaction result BEFORE starting background work: there is
	// nothing to run a regression on when the save was rolled back.
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Copy the ID so the goroutine does not depend on the request-scoped record.
	trainID := record.TrainId

	// Asynchronous regression calculation; failures are only logged.
	go func() {
		// Reload the record with the samples used for regression. The filters
		// are described by the original author as "valid" heart rates and
		// stride frequencies.
		var fullRecord models.StepTrainRecord
		if err := tc.DB.
			Where("train_id = ?", trainID).
			Preload("HeartRates", "heart_rate_type = ?", 1).
			Preload("StrideFreqs", "predict_value = ?", 1).
			First(&fullRecord).Error; err != nil {
			log.Printf("训练记录%d查询失败无法计算回归: %v", trainID, err)
			return
		}

		// Skip the calculation when either series is empty.
		if len(fullRecord.HeartRates) == 0 || len(fullRecord.StrideFreqs) == 0 {
			log.Printf("训练记录%d缺少心率或步频数据跳过回归计算", trainID)
			return
		}

		// Compute and persist the regression results.
		if _, err := tc.GetOrCalculateRegression(fullRecord.TrainId); err != nil {
			log.Printf("训练记录%d回归计算失败: %v", fullRecord.TrainId, err)
			return
		}
		log.Printf("训练记录%d回归结果已保存", fullRecord.TrainId)
	}()

	c.JSON(http.StatusCreated, gin.H{
		"message": "数据保存成功",
		"id":      record.TrainId,
	})
}
// GetTrainingRecords returns one page of training records ordered by
// start_time descending, together with pagination metadata.
//
// Query parameters: pageNum (default 1) and pageSize (default 10). A pageNum
// below 1 is reset to 1; a pageSize outside 1..100 is reset to 10.
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
	// PaginationParams mirrors the pagination query-string parameters.
	type PaginationParams struct {
		PageNum  int `form:"pageNum,default=1"`
		PageSize int `form:"pageSize,default=10"`
	}

	var query PaginationParams
	if bindErr := c.ShouldBindQuery(&query); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	// Normalise out-of-range values instead of rejecting the request.
	pageNum, pageSize := query.PageNum, query.PageSize
	if pageNum < 1 {
		pageNum = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 10
	}

	// Total number of records, needed for the page count.
	var total int64
	countErr := tc.DB.Model(&models.StepTrainRecord{}).Count(&total).Error
	if countErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
		return
	}

	// Fetch the requested page, newest first.
	var page []models.StepTrainRecord
	findErr := tc.DB.
		Order("start_time DESC").
		Offset((pageNum - 1) * pageSize).
		Limit(pageSize).
		Find(&page).Error
	if findErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": findErr.Error()})
		return
	}

	pagination := gin.H{
		"currentPage": pageNum,
		"pageSize":    pageSize,
		"totalPage":   int(math.Ceil(float64(total) / float64(pageSize))),
		"totalList":   total,
	}
	c.JSON(http.StatusOK, gin.H{
		"message": "查询成功",
		"data": gin.H{
			"list":       page,
			"pagination": pagination,
		},
	})
}
// GetTrainingRecordByTrainId returns a single training record, including its
// heart-rate and stride-frequency samples, looked up by the "trainId" path
// parameter.
//
// Responses:
//   - 400 when trainId is empty or is not a valid unsigned integer.
//   - 404 when no record has that train ID.
//   - 500 on any other database error.
//   - 200 with the record on success.
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
	trainId := c.Param("trainId")
	if trainId == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
		return
	}

	// Parse directly as an unsigned value sized like uint, so negative input is
	// rejected here instead of wrapping around when converted to uint.
	tid, err := strconv.ParseUint(trainId, 10, strconv.IntSize)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
		return
	}

	// Load the main record and preload both sample associations.
	var record models.StepTrainRecord
	result := tc.DB.Where("train_id = ?", uint(tid)).
		Preload("HeartRates").
		Preload("StrideFreqs").
		First(&record)
	if result.Error != nil {
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "查询成功",
		"data":    record,
	})
}
// SpeedSegment describes one run of consecutive stride-frequency samples that
// share the same value, as produced by convertStrideFrequencyToSegments.
type SpeedSegment struct {
// Duration is the sample time span divided by 1000 (presumably milliseconds
// converted to seconds — confirm against the sample Time unit).
Duration float64
// Speed is the stride-frequency value shared by the samples in the segment.
Speed float64
}
// performLinearRegression fits y = intercept + slope*x to the supplied points
// using the sajari/regression library.
//
// Each map entry in averages is one observation: the key is x, the value is y.
// When averages is empty the result carries only Equation "无数据"; when the
// library reports an error the failure is logged and the result carries only
// Equation "计算失败".
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	model := new(regression.Regression)
	model.SetObserved("y")
	model.SetVar(0, "x")

	// Feed every (x, y) pair straight into the model.
	for _, sample := range averages {
		for x, y := range sample {
			model.Train(regression.DataPoint(y, []float64{x}))
		}
	}

	if runErr := model.Run(); runErr != nil {
		log.Printf("线性回归计算失败: %v", runErr)
		return models.RegressionResult{Equation: "计算失败"}
	}

	// Coefficient 0 is used as the intercept and coefficient 1 as the slope.
	var (
		intercept = model.Coeff(0)
		slope     = model.Coeff(1)
		rSquared  = model.R2
	)
	return models.RegressionResult{
		RegressionType: models.LinearRegression,
		Slope:          &slope,
		Intercept:      &intercept,
		RSquared:       &rSquared,
		Equation:       model.Formula,
	}
}
// performLogarithmicRegression fits y = LogB + LogA*ln(x+1) to the supplied
// points using the sajari/regression library.
//
// Each map entry in averages is one observation: the key is x, the value is y.
// When averages is empty the result carries only Equation "无数据"; when the
// library reports an error the failure is logged and the result carries only
// Equation "计算失败".
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	model := new(regression.Regression)
	model.SetObserved("y")
	model.SetVar(0, "log(x+1)")

	// Train on the transformed predictor ln(x+1).
	for _, sample := range averages {
		for x, y := range sample {
			transformed := math.Log(x + 1)
			model.Train(regression.DataPoint(y, []float64{transformed}))
		}
	}

	if runErr := model.Run(); runErr != nil {
		log.Printf("对数回归计算失败: %v", runErr)
		return models.RegressionResult{Equation: "计算失败"}
	}

	// Coefficient 1 multiplies the log term; coefficient 0 is the constant.
	var (
		constant  = model.Coeff(0)
		logFactor = model.Coeff(1)
		rSquared  = model.R2
	)
	return models.RegressionResult{
		RegressionType: models.LogarithmicRegression,
		LogA:           &logFactor,
		LogB:           &constant,
		RSquared:       &rSquared,
		Equation:       model.Formula,
	}
}
// 二次回归算法
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
// if len(averages) == 0 {
// return models.RegressionResult{
// Equation: "无数据",
// }
// }
//
// // 收集数据点
// r := new(regression.Regression)
// r.SetObserved("y")
// r.SetVar(0, "x")
// r.SetVar(1, "x²")
//
// for _, m := range averages {
// for speed, hr := range m {
// speedSq := math.Pow(speed, 2)
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
// }
// }
//
// if err := r.Run(); err != nil {
// log.Printf("二次回归计算失败: %v", err)
// return models.RegressionResult{
// Equation: "计算失败",
// }
// }
//
// // 创建结果
// a := r.Coeff(2)
// b := r.Coeff(1)
// c := r.Coeff(0)
// r2 := r.R2
// return models.RegressionResult{
// RegressionType: models.QuadraticRegression,
// QuadraticA: &a,
// QuadraticB: &b,
// QuadraticC: &c,
// RSquared: &r2,
// Equation: r.Formula,
// }
//}
// performQuadraticRegression fits y = a*x² + b*x + c by least squares, solving
// the 3x3 normal equations with Cramer's rule.
//
// Each map entry in averages is one observation: the key is x, the value is y.
// When averages is empty the result carries only Equation "无数据"; when the
// normal-equation determinant is exactly zero it carries only Equation
// "无法拟合". Otherwise the coefficients, R² and a formatted equation string
// (built by formatEquation) are returned.
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
	if len(averages) == 0 {
		return models.RegressionResult{Equation: "无数据"}
	}

	// Collect the points once (map iteration order is not stable) and build the
	// power sums in the same pass.
	type sample struct{ x, y float64 }
	var samples []sample
	var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
	for _, m := range averages {
		for speed, hr := range m {
			samples = append(samples, sample{x: speed, y: hr})
			x2 := speed * speed
			x3 := x2 * speed
			x4 := x3 * speed
			sumX += speed
			sumY += hr
			sumX2 += x2
			sumX3 += x3
			sumX4 += x4
			sumXY += speed * hr
			sumX2Y += x2 * hr
		}
	}
	n := float64(len(samples))

	// Normal equations: normal * [c b a]ᵀ = rhs.
	normal := [3][3]float64{
		{n, sumX, sumX2},
		{sumX, sumX2, sumX3},
		{sumX2, sumX3, sumX4},
	}
	rhs := [3]float64{sumY, sumXY, sumX2Y}

	det := det3x3(normal)
	if det == 0 {
		return models.RegressionResult{Equation: "无法拟合"}
	}

	// Cramer's rule: replace one column with rhs and divide by det.
	solve := func(col int) float64 {
		replaced := normal
		for row := range replaced {
			replaced[row][col] = rhs[row]
		}
		return det3x3(replaced) / det
	}
	c := solve(0) // constant term
	b := solve(1) // linear term
	a := solve(2) // quadratic term

	// Coefficient of determination; stays 0 when y has no variance.
	var ssRes, ssTot float64
	meanY := sumY / n
	for _, p := range samples {
		x := p.x
		y := p.y
		yPred := a*x*x + b*x + c
		ssRes += math.Pow(y-yPred, 2)
		ssTot += math.Pow(y-meanY, 2)
	}
	rSquared := 0.0
	if ssTot != 0 {
		rSquared = 1 - ssRes/ssTot
	}

	return models.RegressionResult{
		RegressionType: models.QuadraticRegression,
		QuadraticA:     &a,
		QuadraticB:     &b,
		QuadraticC:     &c,
		RSquared:       &rSquared,
		Equation:       formatEquation(a, b, c, rSquared),
	}
}
// det3x3 returns the determinant of a 3x3 matrix, expanded by cofactors along
// the first row.
func det3x3(m [3][3]float64) float64 {
	// 2x2 minors obtained by deleting row 0 and column 0, 1, 2 respectively.
	minor0 := m[1][1]*m[2][2] - m[1][2]*m[2][1]
	minor1 := m[1][0]*m[2][2] - m[1][2]*m[2][0]
	minor2 := m[1][0]*m[2][1] - m[1][1]*m[2][0]
	return m[0][0]*minor0 - m[0][1]*minor1 + m[0][2]*minor2
}
// formatEquation renders a quadratic fit as
// "y = A x² ± B x ± C (R² = R)" with every number printed to four decimals.
// The sign of b and c is emitted as a separate " + " / " - " token followed
// by the magnitude.
func formatEquation(a, b, c, r2 float64) string {
	// signed renders a non-leading term with its sign as a separate token.
	signed := func(v float64, suffix string) string {
		text := fmt.Sprintf("%.4f", v)
		if v >= 0 {
			return " + " + text + suffix
		}
		return " - " + strings.TrimPrefix(text, "-") + suffix
	}

	// The leading term keeps its minus sign attached to the number.
	lead := fmt.Sprintf("%.4f", a)
	if !(a >= 0) {
		lead = "-" + strings.TrimPrefix(lead, "-")
	}

	return "y = " + lead + " x²" +
		signed(b, " x") +
		signed(c, "") +
		" (R² = " + fmt.Sprintf("%.4f", r2) + ")"
}
// convertStrideFrequencyToSegments collapses stride-frequency samples into
// segments of constant value.
//
// Samples with Value <= 0 are dropped, the rest are ordered by Time, and each
// run of equal consecutive values becomes one SpeedSegment whose Duration is
// the time span of the run divided by 1000 and whose Speed is the shared
// value. Runs with a non-positive duration are skipped. An empty (non-nil)
// slice is returned when no usable samples exist.
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
	// Keep only samples with a positive value.
	validSteps := make([]models.StepStrideFreq, 0, len(steps))
	for _, s := range steps {
		if s.Value > 0 {
			validSteps = append(validSteps, s)
		}
	}
	if len(validSteps) == 0 {
		return []SpeedSegment{}
	}

	// Order by time. A stable O(n log n) sort keeps samples that share a
	// timestamp in their input order.
	sort.SliceStable(validSteps, func(i, j int) bool {
		return validSteps[i].Time < validSteps[j].Time
	})

	segments := make([]SpeedSegment, 0)
	// appendSegment records a run ending at endTime, skipping empty runs.
	startTime := validSteps[0].Time
	currentValue := validSteps[0].Value
	for i := 1; i < len(validSteps); i++ {
		if validSteps[i].Value == currentValue {
			continue
		}
		// The value changed: close the current run at this sample's time.
		duration := float64(validSteps[i].Time-startTime) / 1000.0
		if duration > 0 {
			segments = append(segments, SpeedSegment{
				Duration: duration,
				Speed:    float64(currentValue),
			})
		}
		startTime = validSteps[i].Time
		currentValue = validSteps[i].Value
	}

	// Close the final run at the last sample's time.
	duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
	if duration > 0 {
		segments = append(segments, SpeedSegment{
			Duration: duration,
			Speed:    float64(currentValue),
		})
	}
	return segments
}
// calculateSegmentAverages computes, for every sufficiently long segment, the
// mean heart-rate value inside a window of that segment.
//
// Segments are laid end to end starting at offset 0. A segment shorter than
// 120-errorThreshold is skipped (its duration still advances the offset). For
// the others the window starts 60 after the segment start and ends at 120
// after it, or at 120-errorThreshold when the segment is shorter than 120.
// Heart-rate sample times are divided by 1000 before comparison with the
// window. Each result is a single-entry map from the segment speed to the
// average; segments with no samples in the window yield nothing.
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
	averages := make([]map[float64]float64, 0)
	tolerance := float64(errorThreshold)
	minRequired := 60 + (60 - tolerance)

	elapsed := 0.0
	for _, seg := range segments {
		// Remember where this segment starts, then advance the running offset.
		segStart := elapsed
		elapsed += seg.Duration

		if seg.Duration < minRequired {
			continue
		}

		// Averaging window relative to the segment start.
		windowStart := segStart + 60
		var windowEnd float64
		if seg.Duration >= 120 {
			windowEnd = segStart + 120
		} else {
			windowEnd = segStart + 120 - tolerance
		}

		// Accumulate the samples that fall inside the window (inclusive).
		total, samples := 0, 0
		for _, hr := range heartRates {
			sec := float64(hr.Time) / 1000.0
			if sec >= windowStart && sec <= windowEnd {
				total += hr.Value
				samples++
			}
		}
		if samples == 0 {
			continue
		}

		mean := float64(total) / float64(samples)
		averages = append(averages, map[float64]float64{seg.Speed: mean})
	}
	return averages
}
// CalculateSegmentAveragesByRealStep converts the stride-frequency samples into
// constant-value segments and returns the per-segment heart-rate averages.
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
	// Error threshold passed to calculateSegmentAverages. The value in effect
	// is 15 (an earlier comment mentioned 5).
	const errorThreshold = 15
	return calculateSegmentAverages(heartRates, convertStrideFrequencyToSegments(steps), errorThreshold)
}
// SaveRegressionResults stores the given regression results for trainId inside
// a single transaction. Each element's TrainId is overwritten with trainId
// before it is written. The first failing write aborts the transaction and its
// error is returned.
//
// NOTE(review): the upsert conflict target is the "id" column only. The
// original comment spoke of a composite unique constraint per regression
// type; confirm whether the conflict target should be that constraint.
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
	upsertAll := func(tx *gorm.DB) error {
		for i := range results {
			row := &results[i]
			row.TrainId = trainId

			// On conflict, refresh every stored coefficient and the timestamp.
			onConflict := clause.OnConflict{
				Columns: []clause.Column{{Name: "id"}},
				DoUpdates: clause.Assignments(map[string]interface{}{
					"equation":    row.Equation,
					"slope":       row.Slope,
					"intercept":   row.Intercept,
					"log_a":       row.LogA,
					"log_b":       row.LogB,
					"quadratic_a": row.QuadraticA,
					"quadratic_b": row.QuadraticB,
					"quadratic_c": row.QuadraticC,
					"r_squared":   row.RSquared,
					"updated_at":  gorm.Expr("CURRENT_TIMESTAMP"),
				}),
			}
			if createErr := tx.Clauses(onConflict).Create(row).Error; createErr != nil {
				return createErr
			}
		}
		return nil
	}
	return tc.DB.Transaction(upsertAll)
}
// GetOrCalculateRegression returns the regression results for trainId,
// computing and persisting them when fewer than three are stored.
//
// When at least three results already exist they are returned as-is.
// Otherwise the training record is loaded with its filtered heart-rate and
// stride-frequency samples, per-segment averages are computed, and linear,
// logarithmic and quadratic fits are produced and saved.
//
// Errors: any database error (including gorm.ErrRecordNotFound when the
// training record does not exist), an error when no segment averages could be
// derived, and any error from saving the results.
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
	// Look for stored results first. A failed lookup is reported instead of
	// being silently treated as "nothing stored".
	var results []models.RegressionResult
	if err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error; err != nil {
		return nil, err
	}
	if len(results) >= 3 {
		return results, nil
	}

	// Load the record with the samples used for the fit.
	var record models.StepTrainRecord
	if err := tc.DB.
		Where("train_id = ?", trainId).
		Preload("HeartRates", "heart_rate_type = ?", 1).
		Preload("StrideFreqs", "predict_value = ?", 1).
		First(&record).Error; err != nil {
		return nil, err
	}

	// Per-segment heart-rate averages are the regression input.
	averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
	if len(averages) == 0 {
		return nil, errors.New("无足够数据进行回归计算")
	}

	linearRes := performLinearRegression(averages)
	logRes := performLogarithmicRegression(averages)
	quadRes := performQuadraticRegression(averages)

	// The type and train ID are set explicitly so they are present even when
	// a fit failed and returned only an Equation message.
	results = []models.RegressionResult{
		{
			RegressionType: models.LinearRegression,
			TrainId:        trainId,
			Equation:       linearRes.Equation,
			Slope:          linearRes.Slope,
			Intercept:      linearRes.Intercept,
			RSquared:       linearRes.RSquared,
		},
		{
			RegressionType: models.LogarithmicRegression,
			TrainId:        trainId,
			Equation:       logRes.Equation,
			LogA:           logRes.LogA,
			LogB:           logRes.LogB,
			RSquared:       logRes.RSquared,
		},
		{
			RegressionType: models.QuadraticRegression,
			TrainId:        trainId,
			Equation:       quadRes.Equation,
			QuadraticA:     quadRes.QuadraticA,
			QuadraticB:     quadRes.QuadraticB,
			QuadraticC:     quadRes.QuadraticC,
			RSquared:       quadRes.RSquared,
		},
	}

	// Persist all three results together.
	if err := tc.SaveRegressionResults(trainId, results); err != nil {
		log.Printf("保存回归结果失败: %v", err)
		return nil, err
	}
	return results, nil
}
// GetRegressionResult returns the regression results for the training record
// identified by the "trainId" path parameter, computing them on demand.
//
// Responses:
//   - 400 when trainId is not a valid unsigned integer.
//   - 404 when the training record does not exist.
//   - 500 on any other failure.
//   - 200 with the list of results on success.
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
	// Parse with the full width of uint so IDs accepted by the other handlers
	// are accepted here as well.
	tid, err := strconv.ParseUint(c.Param("trainId"), 10, strconv.IntSize)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
		return
	}

	result, err := tc.GetOrCalculateRegression(uint(tid))
	if err != nil {
		// A missing training record is a client-side condition, not a server
		// fault.
		status := http.StatusInternalServerError
		if errors.Is(err, gorm.ErrRecordNotFound) {
			status = http.StatusNotFound
		}
		c.JSON(status, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "获取成功",
		"data":    result,
	})
}
// GetTrainingRank reports where a training record ranks among all stored
// regression results of one regression type.
//
// Inputs: path parameter "trainId" and query parameter "type" (the integer
// value of models.RegressionType; only the linear and quadratic types are
// accepted). The rank is 1 plus the number of results that compare as better
// than this record's value, so ties share a rank.
//
// Responses:
//   - 400 for a non-integer or unsupported type, or an invalid trainId.
//   - 404 when the training record, or its value for the requested type, does
//     not exist.
//   - 500 on any other failure.
//   - 200 with trainId, type, rank and total on success.
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
	trainIdStr := c.Param("trainId")
	regressionType, err := strconv.Atoi(c.Query("type"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
		return
	}

	// Only the linear and quadratic types can be ranked.
	regType := models.RegressionType(regressionType)
	if regType != models.LinearRegression && regType != models.QuadraticRegression {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
		return
	}

	tid, err := strconv.ParseUint(trainIdStr, 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
		return
	}
	trainId := uint(tid)

	// Make sure regression results exist for this record.
	if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
		status := http.StatusInternalServerError
		if errors.Is(err, gorm.ErrRecordNotFound) {
			status = http.StatusNotFound
		}
		c.JSON(status, gin.H{"error": "获取回归结果失败:" + err.Error()})
		return
	}

	valueColumn := getValueColumn(regType)
	typeCondition := getTypeCondition(regType)

	// Fetch this record's value. A training record has one row per regression
	// type and only one of them carries the requested coefficient, so rows
	// where it is NULL are excluded; otherwise an arbitrary row could be
	// scanned and fail on NULL.
	var baseValue float64
	if err := tc.DB.Model(&models.RegressionResult{}).
		Select(valueColumn).
		Where("train_id = ?", trainId).
		Where(valueColumn + " IS NOT NULL").
		Row().Scan(&baseValue); err != nil {
		// Row().Scan reports a missing row as sql.ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) || errors.Is(err, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
		return
	}

	// Count all results of this type.
	var total int64
	if err := tc.DB.Model(&models.RegressionResult{}).
		Where(typeCondition).
		Count(&total).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
		return
	}

	// Count the results that compare as better than this record's value.
	betterCondition := fmt.Sprintf("%s %s ?", valueColumn, getComparisonOperator(regType))
	var betterCount int64
	if err := tc.DB.Model(&models.RegressionResult{}).
		Where(typeCondition).
		Where(betterCondition, baseValue).
		Count(&betterCount).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "排名查询成功",
		"data": gin.H{
			"trainId": trainId,
			"type":    regressionType,
			"rank":    betterCount + 1, // ties share a rank
			"total":   total,
		},
	})
}
// getValueColumn returns the SQL expression whose value is ranked for the
// given regression type: the slope for linear fits, the absolute value of the
// quadratic coefficient for quadratic fits, and "" for any other type.
func getValueColumn(regType models.RegressionType) string {
	if regType == models.LinearRegression {
		return "slope"
	}
	if regType == models.QuadraticRegression {
		return "ABS(quadratic_a)"
	}
	return ""
}
// getComparisonOperator returns the SQL operator that selects results ranked
// ahead of a given value: "<" for linear fits (smaller is treated as better),
// ">" for quadratic fits (larger absolute value is treated as better), and ""
// for any other type.
func getComparisonOperator(regType models.RegressionType) string {
	if regType == models.LinearRegression {
		return "<"
	}
	if regType == models.QuadraticRegression {
		return ">"
	}
	return ""
}
// getTypeCondition returns the SQL WHERE fragment that selects the rows taking
// part in the ranking for the given regression type: rows with a slope for
// linear fits, rows with a negative quadratic coefficient for quadratic fits,
// and "" for any other type.
func getTypeCondition(regType models.RegressionType) string {
	if regType == models.LinearRegression {
		return "slope IS NOT NULL"
	}
	if regType == models.QuadraticRegression {
		return "quadratic_a IS NOT NULL AND quadratic_a < 0"
	}
	return ""
}
// compareFloatPtr reports whether a should be ordered before b. A nil a never
// comes first; a non-nil a always comes before a nil b. Otherwise the values
// are compared with < when ascending is true and with > when it is false.
func compareFloatPtr(a, b *float64, ascending bool) bool {
	switch {
	case a == nil:
		// Covers both "only a is nil" and "both are nil".
		return false
	case b == nil:
		return true
	case ascending:
		return *a < *b
	default:
		return *a > *b
	}
}
// compareQuadraticA reports whether a should be ordered before b when ranking
// quadratic coefficients: the larger absolute value comes first. A nil a never
// comes first; a non-nil a always comes before a nil b.
func compareQuadraticA(a, b *float64) bool {
	if a == nil {
		// Covers both "only a is nil" and "both are nil".
		return false
	}
	if b == nil {
		return true
	}
	lhs, rhs := math.Abs(*a), math.Abs(*b)
	return lhs > rhs
}