diff --git a/controllers/step_train.go b/controllers/step_train.go index 33e993f..fa22719 100644 --- a/controllers/step_train.go +++ b/controllers/step_train.go @@ -232,8 +232,8 @@ func performLinearRegression(averages []map[float64]float64) models.RegressionRe } // 创建结果 - slope := r.Coeff(0) - intercept := r.Coeff(1) + slope := r.Coeff(1) + intercept := r.Coeff(0) r2 := r.R2 return models.RegressionResult{ RegressionType: models.LinearRegression, @@ -711,28 +711,24 @@ func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) { }) } -// 获取训练记录的排名 func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) { - // 解析参数 + // 参数解析 trainIdStr := c.Param("trainId") - regressionTypeStr := c.Query("type") - regressionType, err := strconv.Atoi(regressionTypeStr) // 字符串转整型 + regressionType, err := strconv.Atoi(regressionTypeStr) if err != nil { - // 转换失败时返回400错误 c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"}) return } - regType := models.RegressionType(regressionType) - // 验证回归类型 + regType := models.RegressionType(regressionType) if regType != models.LinearRegression && regType != models.QuadraticRegression { c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"}) return } - // 转换trainId + // 转换训练ID tid, err := strconv.ParseUint(trainIdStr, 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"}) @@ -746,70 +742,80 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) { return } - // 获取所有记录用于排名 + // 获取排名数据 var records []models.RegressionResult query := tc.DB.Model(&models.RegressionResult{}) switch regType { case models.LinearRegression: - query = query.Where("slope IS NOT NULL").Select("train_id, slope,regression_type") + query = query.Where("slope IS NOT NULL"). + Select("train_id, slope, regression_type") case models.QuadraticRegression: - query = query.Where("quadratic_a IS NOT NULL"). - Select("train_id, ABS(quadratic_a),regression_type") + query = query.Where("quadratic_a IS NOT NULL AND quadratic_a < 0"). + Select("train_id, quadratic_a, regression_type") } - if err := query.Debug().Find(&records).Error; err != nil { + if err := query.Find(&records).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"}) return } - // 处理无数据情况 if len(records) == 0 { c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"}) return } + // 排序逻辑 sort.Slice(records, func(i, j int) bool { - // 处理空指针情况 - slopeI := records[i].Slope - slopeJ := records[j].Slope - - if slopeI == nil && slopeJ == nil { - return false // 两者均为空时视为相等 + if regType == models.LinearRegression { + // 线性回归:斜率越小排名越高 + return compareFloatPtr(records[i].Slope, records[j].Slope, true) + } else { + // 二次回归:|a|越大排名越高(a为负值) + return compareQuadraticA(records[i].QuadraticA, records[j].QuadraticA) } - if slopeI == nil { - return false // 空值视为极大值(排在最后) - } - if slopeJ == nil { - return true // 非空值始终排在前 - } - - return *slopeI > *slopeJ }) // 计算排名(处理并列) currentRank := 1 rankMap := make(map[uint]int) + lastValue := math.Inf(-1) // 初始极小值 + for i, record := range records { - // 处理第一个记录 + var currentValue float64 + var isSet bool + + switch regType { + case models.LinearRegression: + if record.Slope != nil { + currentValue = *record.Slope + isSet = true + } + case models.QuadraticRegression: + if record.QuadraticA != nil { + currentValue = math.Abs(*record.QuadraticA) + isSet = true + } + } + + if !isSet { + continue // 跳过空值(理论上不应出现) + } + if i == 0 { rankMap[record.TrainId] = currentRank + lastValue = currentValue continue } - if regType == models.LinearRegression { - if records[i].Slope != records[i-1].Slope { - currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1 - } - rankMap[record.TrainId] = currentRank - } else { - if records[i].QuadraticA != records[i-1].QuadraticA { - currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1 - } - rankMap[record.TrainId] = currentRank + // 值变化时更新排名(跳过重复值) + if math.Abs(currentValue-lastValue) > 1e-9 { + currentRank = i + 1 + lastValue = currentValue } - // 检测值是否变化 + + rankMap[record.TrainId] = currentRank } - // 获取当前训练记录的排名 + // 获取指定训练记录的排名 rank, exists := rankMap[trainId] if !exists { c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"}) @@ -827,3 +833,37 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) { }, }) } + +// 辅助函数:比较浮点指针(用于线性回归) +func compareFloatPtr(a, b *float64, ascending bool) bool { + if a == nil && b == nil { + return false + } + if a == nil { + return false // 空值排最后 + } + if b == nil { + return true // 非空值排前 + } + + if ascending { + return *a < *b + } + return *a > *b +} + +// 辅助函数:比较二次项系数(用于二次回归) +func compareQuadraticA(a, b *float64) bool { + if a == nil && b == nil { + return false + } + if a == nil { + return false + } + if b == nil { + return true + } + + // 比较绝对值(a和b都是负值,所以取绝对值后大的排前面) + return math.Abs(*a) > math.Abs(*b) +}