fix: linear slope.

This commit is contained in:
2025-08-05 10:18:35 +08:00
parent 3078c13e14
commit d3e87fda67

View File

@ -232,8 +232,8 @@ func performLinearRegression(averages []map[float64]float64) models.RegressionRe
} }
// 创建结果 // 创建结果
slope := r.Coeff(0) slope := r.Coeff(1)
intercept := r.Coeff(1) intercept := r.Coeff(0)
r2 := r.R2 r2 := r.R2
return models.RegressionResult{ return models.RegressionResult{
RegressionType: models.LinearRegression, RegressionType: models.LinearRegression,
@ -711,28 +711,24 @@ func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
}) })
} }
// 获取训练记录的排名
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) { func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
// 解析参数 // 参数解析
trainIdStr := c.Param("trainId") trainIdStr := c.Param("trainId")
regressionTypeStr := c.Query("type") regressionTypeStr := c.Query("type")
regressionType, err := strconv.Atoi(regressionTypeStr) // 字符串转整型 regressionType, err := strconv.Atoi(regressionTypeStr)
if err != nil { if err != nil {
// 转换失败时返回400错误
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"}) c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
return return
} }
regType := models.RegressionType(regressionType)
// 验证回归类型 // 验证回归类型
regType := models.RegressionType(regressionType)
if regType != models.LinearRegression && regType != models.QuadraticRegression { if regType != models.LinearRegression && regType != models.QuadraticRegression {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"}) c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
return return
} }
// 转换trainId // 转换训练ID
tid, err := strconv.ParseUint(trainIdStr, 10, 64) tid, err := strconv.ParseUint(trainIdStr, 10, 64)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"}) c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
@ -746,70 +742,80 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
return return
} }
// 获取所有记录用于排名 // 获取排名数据
var records []models.RegressionResult var records []models.RegressionResult
query := tc.DB.Model(&models.RegressionResult{}) query := tc.DB.Model(&models.RegressionResult{})
switch regType { switch regType {
case models.LinearRegression: case models.LinearRegression:
query = query.Where("slope IS NOT NULL").Select("train_id, slope,regression_type") query = query.Where("slope IS NOT NULL").
Select("train_id, slope, regression_type")
case models.QuadraticRegression: case models.QuadraticRegression:
query = query.Where("quadratic_a IS NOT NULL"). query = query.Where("quadratic_a IS NOT NULL AND quadratic_a < 0").
Select("train_id, ABS(quadratic_a),regression_type") Select("train_id, quadratic_a, regression_type")
} }
if err := query.Debug().Find(&records).Error; err != nil { if err := query.Find(&records).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
return return
} }
// 处理无数据情况
if len(records) == 0 { if len(records) == 0 {
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"}) c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
return return
} }
// 排序逻辑
sort.Slice(records, func(i, j int) bool { sort.Slice(records, func(i, j int) bool {
// 处理空指针情况 if regType == models.LinearRegression {
slopeI := records[i].Slope // 线性回归:斜率越小排名越高
slopeJ := records[j].Slope return compareFloatPtr(records[i].Slope, records[j].Slope, true)
} else {
if slopeI == nil && slopeJ == nil { // 二次回归:|a|越大排名越高a为负值
return false // 两者均为空时视为相等 return compareQuadraticA(records[i].QuadraticA, records[j].QuadraticA)
} }
if slopeI == nil {
return false // 空值视为极大值(排在最后)
}
if slopeJ == nil {
return true // 非空值始终排在前
}
return *slopeI > *slopeJ
}) })
// 计算排名(处理并列) // 计算排名(处理并列)
currentRank := 1 currentRank := 1
rankMap := make(map[uint]int) rankMap := make(map[uint]int)
lastValue := math.Inf(-1) // 初始极小值
for i, record := range records { for i, record := range records {
// 处理第一个记录 var currentValue float64
var isSet bool
switch regType {
case models.LinearRegression:
if record.Slope != nil {
currentValue = *record.Slope
isSet = true
}
case models.QuadraticRegression:
if record.QuadraticA != nil {
currentValue = math.Abs(*record.QuadraticA)
isSet = true
}
}
if !isSet {
continue // 跳过空值(理论上不应出现)
}
if i == 0 { if i == 0 {
rankMap[record.TrainId] = currentRank rankMap[record.TrainId] = currentRank
lastValue = currentValue
continue continue
} }
if regType == models.LinearRegression { // 值变化时更新排名(跳过重复值)
if records[i].Slope != records[i-1].Slope { if math.Abs(currentValue-lastValue) > 1e-9 {
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1 currentRank = i + 1
} lastValue = currentValue
rankMap[record.TrainId] = currentRank
} else {
if records[i].QuadraticA != records[i-1].QuadraticA {
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
}
rankMap[record.TrainId] = currentRank
} }
// 检测值是否变化
rankMap[record.TrainId] = currentRank
} }
// 获取当前训练记录的排名 // 获取指定训练记录的排名
rank, exists := rankMap[trainId] rank, exists := rankMap[trainId]
if !exists { if !exists {
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
@ -827,3 +833,37 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
}, },
}) })
} }
// 辅助函数:比较浮点指针(用于线性回归)
func compareFloatPtr(a, b *float64, ascending bool) bool {
if a == nil && b == nil {
return false
}
if a == nil {
return false // 空值排最后
}
if b == nil {
return true // 非空值排前
}
if ascending {
return *a < *b
}
return *a > *b
}
// 辅助函数:比较二次项系数(用于二次回归)
func compareQuadraticA(a, b *float64) bool {
if a == nil && b == nil {
return false
}
if a == nil {
return false
}
if b == nil {
return true
}
// 比较绝对值a和b都是负值所以取绝对值后大的排前面
return math.Abs(*a) > math.Abs(*b)
}