fix: linear slope.
This commit is contained in:
@ -232,8 +232,8 @@ func performLinearRegression(averages []map[float64]float64) models.RegressionRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建结果
|
// 创建结果
|
||||||
slope := r.Coeff(0)
|
slope := r.Coeff(1)
|
||||||
intercept := r.Coeff(1)
|
intercept := r.Coeff(0)
|
||||||
r2 := r.R2
|
r2 := r.R2
|
||||||
return models.RegressionResult{
|
return models.RegressionResult{
|
||||||
RegressionType: models.LinearRegression,
|
RegressionType: models.LinearRegression,
|
||||||
@ -711,28 +711,24 @@ func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取训练记录的排名
|
|
||||||
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
||||||
// 解析参数
|
// 参数解析
|
||||||
trainIdStr := c.Param("trainId")
|
trainIdStr := c.Param("trainId")
|
||||||
|
|
||||||
regressionTypeStr := c.Query("type")
|
regressionTypeStr := c.Query("type")
|
||||||
regressionType, err := strconv.Atoi(regressionTypeStr) // 字符串转整型
|
regressionType, err := strconv.Atoi(regressionTypeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 转换失败时返回400错误
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
regType := models.RegressionType(regressionType)
|
|
||||||
|
|
||||||
// 验证回归类型
|
// 验证回归类型
|
||||||
|
regType := models.RegressionType(regressionType)
|
||||||
if regType != models.LinearRegression && regType != models.QuadraticRegression {
|
if regType != models.LinearRegression && regType != models.QuadraticRegression {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 转换trainId
|
// 转换训练ID
|
||||||
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
|
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||||
@ -746,70 +742,80 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取所有记录用于排名
|
// 获取排名数据
|
||||||
var records []models.RegressionResult
|
var records []models.RegressionResult
|
||||||
query := tc.DB.Model(&models.RegressionResult{})
|
query := tc.DB.Model(&models.RegressionResult{})
|
||||||
switch regType {
|
switch regType {
|
||||||
case models.LinearRegression:
|
case models.LinearRegression:
|
||||||
query = query.Where("slope IS NOT NULL").Select("train_id, slope,regression_type")
|
query = query.Where("slope IS NOT NULL").
|
||||||
|
Select("train_id, slope, regression_type")
|
||||||
case models.QuadraticRegression:
|
case models.QuadraticRegression:
|
||||||
query = query.Where("quadratic_a IS NOT NULL").
|
query = query.Where("quadratic_a IS NOT NULL AND quadratic_a < 0").
|
||||||
Select("train_id, ABS(quadratic_a),regression_type")
|
Select("train_id, quadratic_a, regression_type")
|
||||||
}
|
}
|
||||||
if err := query.Debug().Find(&records).Error; err != nil {
|
if err := query.Find(&records).Error; err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询排名数据失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理无数据情况
|
|
||||||
if len(records) == 0 {
|
if len(records) == 0 {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "无可用数据计算排名"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 排序逻辑
|
||||||
sort.Slice(records, func(i, j int) bool {
|
sort.Slice(records, func(i, j int) bool {
|
||||||
// 处理空指针情况
|
if regType == models.LinearRegression {
|
||||||
slopeI := records[i].Slope
|
// 线性回归:斜率越小排名越高
|
||||||
slopeJ := records[j].Slope
|
return compareFloatPtr(records[i].Slope, records[j].Slope, true)
|
||||||
|
} else {
|
||||||
if slopeI == nil && slopeJ == nil {
|
// 二次回归:|a|越大排名越高(a为负值)
|
||||||
return false // 两者均为空时视为相等
|
return compareQuadraticA(records[i].QuadraticA, records[j].QuadraticA)
|
||||||
}
|
}
|
||||||
if slopeI == nil {
|
|
||||||
return false // 空值视为极大值(排在最后)
|
|
||||||
}
|
|
||||||
if slopeJ == nil {
|
|
||||||
return true // 非空值始终排在前
|
|
||||||
}
|
|
||||||
|
|
||||||
return *slopeI > *slopeJ
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算排名(处理并列)
|
// 计算排名(处理并列)
|
||||||
currentRank := 1
|
currentRank := 1
|
||||||
rankMap := make(map[uint]int)
|
rankMap := make(map[uint]int)
|
||||||
|
lastValue := math.Inf(-1) // 初始极小值
|
||||||
|
|
||||||
for i, record := range records {
|
for i, record := range records {
|
||||||
// 处理第一个记录
|
var currentValue float64
|
||||||
|
var isSet bool
|
||||||
|
|
||||||
|
switch regType {
|
||||||
|
case models.LinearRegression:
|
||||||
|
if record.Slope != nil {
|
||||||
|
currentValue = *record.Slope
|
||||||
|
isSet = true
|
||||||
|
}
|
||||||
|
case models.QuadraticRegression:
|
||||||
|
if record.QuadraticA != nil {
|
||||||
|
currentValue = math.Abs(*record.QuadraticA)
|
||||||
|
isSet = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSet {
|
||||||
|
continue // 跳过空值(理论上不应出现)
|
||||||
|
}
|
||||||
|
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
rankMap[record.TrainId] = currentRank
|
rankMap[record.TrainId] = currentRank
|
||||||
|
lastValue = currentValue
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if regType == models.LinearRegression {
|
// 值变化时更新排名(跳过重复值)
|
||||||
if records[i].Slope != records[i-1].Slope {
|
if math.Abs(currentValue-lastValue) > 1e-9 {
|
||||||
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
|
currentRank = i + 1
|
||||||
}
|
lastValue = currentValue
|
||||||
rankMap[record.TrainId] = currentRank
|
|
||||||
} else {
|
|
||||||
if records[i].QuadraticA != records[i-1].QuadraticA {
|
|
||||||
currentRank = i + 1 // 值变化时,当前排名 = 索引 + 1
|
|
||||||
}
|
|
||||||
rankMap[record.TrainId] = currentRank
|
|
||||||
}
|
}
|
||||||
// 检测值是否变化
|
|
||||||
|
rankMap[record.TrainId] = currentRank
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取当前训练记录的排名
|
// 获取指定训练记录的排名
|
||||||
rank, exists := rankMap[trainId]
|
rank, exists := rankMap[trainId]
|
||||||
if !exists {
|
if !exists {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "训练记录未包含在排名中"})
|
||||||
@ -827,3 +833,37 @@ func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 辅助函数:比较浮点指针(用于线性回归)
|
||||||
|
func compareFloatPtr(a, b *float64, ascending bool) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a == nil {
|
||||||
|
return false // 空值排最后
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return true // 非空值排前
|
||||||
|
}
|
||||||
|
|
||||||
|
if ascending {
|
||||||
|
return *a < *b
|
||||||
|
}
|
||||||
|
return *a > *b
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:比较二次项系数(用于二次回归)
|
||||||
|
func compareQuadraticA(a, b *float64) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 比较绝对值(a和b都是负值,所以取绝对值后大的排前面)
|
||||||
|
return math.Abs(*a) > math.Abs(*b)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user