fix: rank.
This commit is contained in:
@ -12,7 +12,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@ -741,99 +740,99 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 获取指定训练的基准值
|
||||||
|
var baseValue float64
|
||||||
|
baseQuery := tc.DB.Model(&models.RegressionResult{}).
|
||||||
|
Select(getValueColumn(regType)).
|
||||||
|
Where("train_id = ?", trainId)
|
||||||
|
|
||||||
// 获取排名数据
|
if err := baseQuery.Row().Scan(&baseValue); err != nil {
|
||||||
var records []models.RegressionResult
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
query := tc.DB.Model(&models.RegressionResult{})
|
c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
|
||||||
switch regType {
|
return
|
||||||
case models.LinearRegression:
|
}
|
||||||
query = query.Where("slope IS NOT NULL").
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
|
||||||
Select("train_id, slope, regression_type")
|
|
||||||
case models.QuadraticRegression:
|
|
||||||
query = query.Where("quadratic_a IS NOT NULL AND quadratic_a < 0").
|
|
||||||
Select("train_id, quadratic_a, regression_type")
|
|
||||||
}
|
|
||||||
if err := query.Find(&records).Error; err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(records) == 0 {
|
// 在数据库中进行排名计算
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
|
var rank struct {
|
||||||
|
BetterCount int64
|
||||||
|
Total int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// 动态生成比较条件
|
||||||
|
betterCondition := fmt.Sprintf("%s %s ?",
|
||||||
|
getValueColumn(regType),
|
||||||
|
getComparisonOperator(regType))
|
||||||
|
|
||||||
|
totalQuery := tc.DB.Model(&models.RegressionResult{}).
|
||||||
|
Where(getTypeCondition(regType))
|
||||||
|
|
||||||
|
if err := totalQuery.Count(&rank.Total).Error; err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 排序逻辑
|
if err := tc.DB.Model(&models.RegressionResult{}).
|
||||||
sort.Slice(records, func(i, j int) bool {
|
Where(getTypeCondition(regType)).
|
||||||
if regType == models.LinearRegression {
|
Where(betterCondition, baseValue).
|
||||||
// 线性回归:斜率越小排名越高
|
Count(&rank.BetterCount).Error; err != nil {
|
||||||
return compareFloatPtr(records[i].Slope, records[j].Slope, true)
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
|
||||||
} else {
|
|
||||||
// 二次回归:|a|越大排名越高(a为负值)
|
|
||||||
return compareQuadraticA(records[i].QuadraticA, records[j].QuadraticA)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// 计算排名(处理并列)
|
|
||||||
currentRank := 1
|
|
||||||
rankMap := make(map[uint]int)
|
|
||||||
lastValue := math.Inf(-1) // 初始极小值
|
|
||||||
|
|
||||||
for i, record := range records {
|
|
||||||
var currentValue float64
|
|
||||||
var isSet bool
|
|
||||||
|
|
||||||
switch regType {
|
|
||||||
case models.LinearRegression:
|
|
||||||
if record.Slope != nil {
|
|
||||||
currentValue = *record.Slope
|
|
||||||
isSet = true
|
|
||||||
}
|
|
||||||
case models.QuadraticRegression:
|
|
||||||
if record.QuadraticA != nil {
|
|
||||||
currentValue = math.Abs(*record.QuadraticA)
|
|
||||||
isSet = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isSet {
|
|
||||||
continue // 跳过空值(理论上不应出现)
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == 0 {
|
|
||||||
rankMap[record.TrainId] = currentRank
|
|
||||||
lastValue = currentValue
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 值变化时更新排名(跳过重复值)
|
|
||||||
if math.Abs(currentValue-lastValue) > 1e-9 {
|
|
||||||
currentRank = i + 1
|
|
||||||
lastValue = currentValue
|
|
||||||
}
|
|
||||||
|
|
||||||
rankMap[record.TrainId] = currentRank
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取指定训练记录的排名
|
|
||||||
rank, exists := rankMap[trainId]
|
|
||||||
if !exists {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 计算实际排名 (并列排名)
|
||||||
|
currentRank := rank.BetterCount + 1
|
||||||
|
|
||||||
// 返回响应
|
// 返回响应
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "排名查询成功",
|
"message": "排名查询成功",
|
||||||
"data": gin.H{
|
"data": gin.H{
|
||||||
"trainId": trainId,
|
"trainId": trainId,
|
||||||
"type": regressionType,
|
"type": regressionType,
|
||||||
"rank": rank,
|
"rank": currentRank,
|
||||||
"total": len(records),
|
"total": rank.Total,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取排序字段名
|
||||||
|
func getValueColumn(regType models.RegressionType) string {
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
return "slope"
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
return "ABS(quadratic_a)" // 计算绝对值
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取比较操作符
|
||||||
|
func getComparisonOperator(regType models.RegressionType) string {
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
return "<" // 线性回归:值越小越好
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
return ">" // 二次回归:绝对值越大越好
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取类型条件
|
||||||
|
func getTypeCondition(regType models.RegressionType) string {
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
return "slope IS NOT NULL"
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
return "quadratic_a IS NOT NULL AND quadratic_a < 0" // 确保是负值
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 辅助函数:比较浮点指针(用于线性回归)
|
// 辅助函数:比较浮点指针(用于线性回归)
|
||||||
func compareFloatPtr(a, b *float64, ascending bool) bool {
|
func compareFloatPtr(a, b *float64, ascending bool) bool {
|
||||||
if a == nil && b == nil {
|
if a == nil && b == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user