fix: rank.

This commit is contained in:
2025-08-05 10:43:19 +08:00
parent d3e87fda67
commit 563fdd8a7e

View File

@ -12,7 +12,6 @@ import (
"log" "log"
"math" "math"
"net/http" "net/http"
"sort"
"strconv" "strconv"
"strings" "strings"
) )
@ -741,99 +740,99 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
return return
} }
// 获取指定训练的基准值
var baseValue float64
baseQuery := tc.DB.Model(&models.RegressionResult{}).
Select(getValueColumn(regType)).
Where("train_id = ?", trainId)
// 获取排名数据 if err := baseQuery.Row().Scan(&baseValue); err != nil {
var records []models.RegressionResult if errors.Is(err, gorm.ErrRecordNotFound) {
query := tc.DB.Model(&models.RegressionResult{}) c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
switch regType { return
case models.LinearRegression:
query = query.Where("slope IS NOT NULL").
Select("train_id, slope, regression_type")
case models.QuadraticRegression:
query = query.Where("quadratic_a IS NOT NULL AND quadratic_a < 0").
Select("train_id, quadratic_a, regression_type")
} }
if err := query.Find(&records).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
return return
} }
if len(records) == 0 { // 在数据库中进行排名计算
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"}) var rank struct {
BetterCount int64
Total int64
}
// 动态生成比较条件
betterCondition := fmt.Sprintf("%s %s ?",
getValueColumn(regType),
getComparisonOperator(regType))
totalQuery := tc.DB.Model(&models.RegressionResult{}).
Where(getTypeCondition(regType))
if err := totalQuery.Count(&rank.Total).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
return return
} }
// 排序逻辑 if err := tc.DB.Model(&models.RegressionResult{}).
sort.Slice(records, func(i, j int) bool { Where(getTypeCondition(regType)).
if regType == models.LinearRegression { Where(betterCondition, baseValue).
// 线性回归:斜率越小排名越高 Count(&rank.BetterCount).Error; err != nil {
return compareFloatPtr(records[i].Slope, records[j].Slope, true) c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
} else {
// 二次回归:|a|越大排名越高a为负值
return compareQuadraticA(records[i].QuadraticA, records[j].QuadraticA)
}
})
// 计算排名(处理并列)
currentRank := 1
rankMap := make(map[uint]int)
lastValue := math.Inf(-1) // 初始极小值
for i, record := range records {
var currentValue float64
var isSet bool
switch regType {
case models.LinearRegression:
if record.Slope != nil {
currentValue = *record.Slope
isSet = true
}
case models.QuadraticRegression:
if record.QuadraticA != nil {
currentValue = math.Abs(*record.QuadraticA)
isSet = true
}
}
if !isSet {
continue // 跳过空值(理论上不应出现)
}
if i == 0 {
rankMap[record.TrainId] = currentRank
lastValue = currentValue
continue
}
// 值变化时更新排名(跳过重复值)
if math.Abs(currentValue-lastValue) > 1e-9 {
currentRank = i + 1
lastValue = currentValue
}
rankMap[record.TrainId] = currentRank
}
// 获取指定训练记录的排名
rank, exists := rankMap[trainId]
if !exists {
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
return return
} }
// 计算实际排名 (并列排名)
currentRank := rank.BetterCount + 1
// 返回响应 // 返回响应
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "排名查询成功", "message": "排名查询成功",
"data": gin.H{ "data": gin.H{
"trainId": trainId, "trainId": trainId,
"type": regressionType, "type": regressionType,
"rank": rank, "rank": currentRank,
"total": len(records), "total": rank.Total,
}, },
}) })
} }
// 辅助函数:获取排序字段名
func getValueColumn(regType models.RegressionType) string {
switch regType {
case models.LinearRegression:
return "slope"
case models.QuadraticRegression:
return "ABS(quadratic_a)" // 计算绝对值
default:
return ""
}
}
// 辅助函数:获取比较操作符
func getComparisonOperator(regType models.RegressionType) string {
switch regType {
case models.LinearRegression:
return "<" // 线性回归:值越小越好
case models.QuadraticRegression:
return ">" // 二次回归:绝对值越大越好
default:
return ""
}
}
// 辅助函数:获取类型条件
func getTypeCondition(regType models.RegressionType) string {
switch regType {
case models.LinearRegression:
return "slope IS NOT NULL"
case models.QuadraticRegression:
return "quadratic_a IS NOT NULL AND quadratic_a < 0" // 确保是负值
default:
return ""
}
}
// 辅助函数:比较浮点指针(用于线性回归) // 辅助函数:比较浮点指针(用于线性回归)
func compareFloatPtr(a, b *float64, ascending bool) bool { func compareFloatPtr(a, b *float64, ascending bool) bool {
if a == nil && b == nil { if a == nil && b == nil {