diff --git a/controllers/step_train.go b/controllers/step_train.go index fa22719..f4ac3e5 100644 --- a/controllers/step_train.go +++ b/controllers/step_train.go @@ -12,7 +12,6 @@ import ( "log" "math" "net/http" - "sort" "strconv" "strings" ) @@ -741,99 +740,99 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()}) return } + // 获取指定训练的基准值 + var baseValue float64 + baseQuery := tc.DB.Model(&models.RegressionResult{}). + Select(getValueColumn(regType)). + Where("train_id = ?", trainId) - // 获取排名数据 - var records []models.RegressionResult - query := tc.DB.Model(&models.RegressionResult{}) - switch regType { - case models.LinearRegression: - query = query.Where("slope IS NOT NULL"). - Select("train_id, slope, regression_type") - case models.QuadraticRegression: - query = query.Where("quadratic_a IS NOT NULL AND quadratic_a < 0"). - Select("train_id, quadratic_a, regression_type") - } - if err := query.Find(&records).Error; err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"}) + if err := baseQuery.Row().Scan(&baseValue); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"}) return } - if len(records) == 0 { - c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"}) + // 在数据库中进行排名计算 + var rank struct { + BetterCount int64 + Total int64 + } + + // 动态生成比较条件 + betterCondition := fmt.Sprintf("%s %s ?", + getValueColumn(regType), + getComparisonOperator(regType)) + + totalQuery := tc.DB.Model(&models.RegressionResult{}). + Where(getTypeCondition(regType)) + + if err := totalQuery.Count(&rank.Total).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"}) return } - // 排序逻辑 - sort.Slice(records, func(i, j int) bool { - if regType == models.LinearRegression { - // 线性回归:斜率越小排名越高 - return compareFloatPtr(records[i].Slope, records[j].Slope, true) - } else { - // 二次回归:|a|越大排名越高(a为负值) - return compareQuadraticA(records[i].QuadraticA, records[j].QuadraticA) - } - }) - - // 计算排名(处理并列) - currentRank := 1 - rankMap := make(map[uint]int) - lastValue := math.Inf(-1) // 初始极小值 - - for i, record := range records { - var currentValue float64 - var isSet bool - - switch regType { - case models.LinearRegression: - if record.Slope != nil { - currentValue = *record.Slope - isSet = true - } - case models.QuadraticRegression: - if record.QuadraticA != nil { - currentValue = math.Abs(*record.QuadraticA) - isSet = true - } - } - - if !isSet { - continue // 跳过空值(理论上不应出现) - } - - if i == 0 { - rankMap[record.TrainId] = currentRank - lastValue = currentValue - continue - } - - // 值变化时更新排名(跳过重复值) - if math.Abs(currentValue-lastValue) > 1e-9 { - currentRank = i + 1 - lastValue = currentValue - } - - rankMap[record.TrainId] = currentRank - } - - // 获取指定训练记录的排名 - rank, exists := rankMap[trainId] - if !exists { - c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"}) + if err := tc.DB.Model(&models.RegressionResult{}). + Where(getTypeCondition(regType)). + Where(betterCondition, baseValue). + Count(&rank.BetterCount).Error; err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"}) return } + // 计算实际排名 (并列排名) + currentRank := rank.BetterCount + 1 + // 返回响应 c.JSON(http.StatusOK, gin.H{ "message": "排名查询成功", "data": gin.H{ "trainId": trainId, "type": regressionType, - "rank": rank, - "total": len(records), + "rank": currentRank, + "total": rank.Total, }, }) } +// 辅助函数:获取排序字段名 +func getValueColumn(regType models.RegressionType) string { + switch regType { + case models.LinearRegression: + return "slope" + case models.QuadraticRegression: + return "ABS(quadratic_a)" // 计算绝对值 + default: + return "" + } +} + +// 辅助函数:获取比较操作符 +func getComparisonOperator(regType models.RegressionType) string { + switch regType { + case models.LinearRegression: + return "<" // 线性回归:值越小越好 + case models.QuadraticRegression: + return ">" // 二次回归:绝对值越大越好 + default: + return "" + } +} + +// 辅助函数:获取类型条件 +func getTypeCondition(regType models.RegressionType) string { + switch regType { + case models.LinearRegression: + return "slope IS NOT NULL" + case models.QuadraticRegression: + return "quadratic_a IS NOT NULL AND quadratic_a < 0" // 确保是负值 + default: + return "" + } +} + // 辅助函数:比较浮点指针(用于线性回归) func compareFloatPtr(a, b *float64, ascending bool) bool { if a == nil && b == nil {