Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9a1dedf425 | |||
| d7885c442f | |||
| 85ec3cba4a | |||
| 1c38601fe0 | |||
| 876916010f | |||
| 1a252a12be | |||
| 8d8dd26a2c | |||
| a14c553736 | |||
| 563fdd8a7e | |||
| d3e87fda67 | |||
| 3078c13e14 | |||
| 7a2a44e327 | |||
| b00767dfcd | |||
| 2aa22b1385 | |||
| 07963218cd | |||
| 0852f4bc23 | |||
| 19148d7d35 | |||
| 4e10359a5b | |||
| 8510d19baa |
Executable
+10
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
# docker exec hr_data_analyzer_db_1 mysqldump -uroot -proot training_db > train.sql
|
||||
# docker exec hr_data_analyzer_db_1 pg_dump -U postgres training_db > train.sql
|
||||
|
||||
docker exec -e PGPASSWORD=root hr_data_analyzer_db_1 pg_dump \
|
||||
-U postgres \
|
||||
--data-only \
|
||||
--inserts \
|
||||
-t step_heart_rates -t step_stride_freqs -t step_train_records \
|
||||
training_db > data_only.sql
|
||||
@@ -1,6 +0,0 @@
|
||||
database:
|
||||
host: localhost #when use docker change to "db"
|
||||
port: 5432
|
||||
user: postgres
|
||||
password: root
|
||||
name: training_db
|
||||
@@ -0,0 +1,270 @@
|
||||
// controllers/ai.go
|
||||
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context" // 在此处添加 context 导入
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"hr_receiver/util"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 配置文件 (与 main.go 保持一致)
|
||||
const (
|
||||
BaseURL = "https://tokenhub.tencentmaas.com/v1/"
|
||||
APIKey = "sk-KJNOFMltNzhSKh2IxW3G3MKmZF3q2RrOlvSk497CfTHp1Z4u" // 请替换为实际的 API Key
|
||||
Model = "deepseek-v4-flash"
|
||||
)
|
||||
|
||||
// readDocxContent 读取 .docx 文件并将其转换为结构化文本
|
||||
// 修改为先保存临时文件再读取
|
||||
func readDocxContent(fileHeader *multipart.FileHeader) (string, error) {
|
||||
// 1. 创建临时文件
|
||||
tempFile, err := os.CreateTemp("", "upload_*.docx")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name()) // 确保函数结束时删除临时文件
|
||||
defer tempFile.Close()
|
||||
|
||||
// 2. 打开上传的文件流
|
||||
src, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 3. 将上传的文件内容复制到临时文件
|
||||
_, err = io.Copy(tempFile, src)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to copy file to temporary location: %w", err)
|
||||
}
|
||||
|
||||
// 4. 获取临时文件的完整路径
|
||||
tempFilePath := tempFile.Name()
|
||||
str, err := util.DocxToStructuredPrompt(tempFilePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse docx with go-docx: %w", err)
|
||||
}
|
||||
// 注意:表格、图片等复杂元素的处理可能需要更复杂的逻辑,这里仅处理简单文本
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// readCSVContent 读取 .csv 文件内容
|
||||
// 修改为先保存临时文件再读取
|
||||
// readCSVContent 读取 .csv 文件内容
|
||||
// 修改为先保存临时文件再读取,并增加了数据压缩逻辑以解决 token 超长问题
|
||||
func readCSVContent(fileHeader *multipart.FileHeader) (string, error) {
|
||||
// 1. 创建临时文件
|
||||
tempFile, err := os.CreateTemp("", "upload_*.csv")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary file: %w", err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name()) // 确保函数结束时删除临时文件
|
||||
defer tempFile.Close()
|
||||
|
||||
// 2. 打开上传的文件流
|
||||
src, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open uploaded file: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 3. 将上传的文件内容复制到临时文件
|
||||
_, err = io.Copy(tempFile, src)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to copy file to temporary location: %w", err)
|
||||
}
|
||||
|
||||
// 4. 读取临时文件内容
|
||||
content, err := ioutil.ReadFile(tempFile.Name())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read CSV content from temporary file: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
var compressedLines []string
|
||||
|
||||
for i, line := range lines {
|
||||
// 1. 必须保留第一行(表头),让 AI 知道每一列是什么
|
||||
if i == 0 {
|
||||
compressedLines = append(compressedLines, line)
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. 跳过空行
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 抽样逻辑:每 4 行保留 1 行
|
||||
// i=1 是数据第1行,i=2 是数据第2行...
|
||||
// (i-1)%4 == 0 意味着:数据第1, 5, 9, 13... 行会被保留
|
||||
if (i-1)%4 == 0 {
|
||||
compressedLines = append(compressedLines, line)
|
||||
}
|
||||
}
|
||||
// --- 修改逻辑结束 ---
|
||||
|
||||
// 将处理后的行重新组合成字符串
|
||||
resultContent := strings.Join(compressedLines, "\n")
|
||||
// --- 新增逻辑结束 ---
|
||||
|
||||
return resultContent, nil
|
||||
}
|
||||
|
||||
// buildAnalysisPrompt 构建发送给 AI 的提示词
|
||||
func buildAnalysisPrompt(teachingPlanContent, heartRateContent string) string {
|
||||
return fmt.Sprintf(`请根据以下体育课堂的教案和心率监测数据,生成一份详细的课堂分析报告:
|
||||
|
||||
## 教案内容:
|
||||
%s
|
||||
|
||||
## 心率监测数据:
|
||||
%s
|
||||
|
||||
这是一份幼儿园体育课的教案和课程心率监测数据,请帮对照分析课程教学效果,运动量和运动负荷情况是否科学,并提出课程设计的优化方案。
|
||||
|
||||
优化方案参考如下格式,教学过程需要详细一些:
|
||||
# 幼儿体育教案(华侨大学版本)
|
||||
|
||||
| 项目 | 内容 |
|
||||
| ------------ | -------------------------------- |
|
||||
| **课程名** | |
|
||||
| **年段** | 小 中 大 |
|
||||
| **教师姓名** | |
|
||||
| **时间** | 年 月 日 |
|
||||
| **地点** | |
|
||||
| **人数** | 男: 女: |
|
||||
| **时长** | 分钟 |
|
||||
| **天气预报** | 晴 雨 阴;温度 ℃ |
|
||||
| **器材准备** | |
|
||||
|
||||
## 教学目标
|
||||
|
||||
| 类型 | 目标 |
|
||||
| -------- | ------------ |
|
||||
| **体能目标** | |
|
||||
| **技能目标** | |
|
||||
| **情感目标** | |
|
||||
|
||||
## 教学过程
|
||||
|
||||
| 阶段 | 阶段 | 项目名称 | 引导语及教学方法 | 队形/站位/留意点 | 目标心率区间 | 时间(分) |
|
||||
| ---------- | -------- | ----------------------------- | ------------------------ | --------------------- | ------------ | ---------- |
|
||||
| **准备部分** | 热身 | | | | | 3 |
|
||||
| | 注意力游戏 | | | | | 3 |
|
||||
| **正课部分** | 基本素质练习及常规意识培养环节 | | | | | 5 |
|
||||
| | 复习环节 | | | | | 5 |
|
||||
| | 新授环节 | | | | | 8 |
|
||||
| **结束部分** | 社会性及情感目标游戏 | | | | | 4 |
|
||||
| | 整理放松 | | | | | 2 |
|
||||
|
||||
请以专业体育教师的视角,提供详细的数据分析和教学建议。请直接输出报告内容,不要包含“好的”、“收到”、“作为一名...”等任何开场白或客套话。`, teachingPlanContent, heartRateContent)
|
||||
}
|
||||
|
||||
// callAIForAnalysis 调用大模型进行分析
|
||||
func callAIForAnalysis(prompt string) (string, error) {
|
||||
sizeInBytes := len(prompt)
|
||||
sizeInKB := float64(sizeInBytes) / 1024.0
|
||||
|
||||
// 在日志中打印大小,保留两位小数
|
||||
log.Printf("=== 发送给 AI 的内容大小: %.2f KB (%d 字节) ===", sizeInKB, sizeInBytes)
|
||||
config := openai.DefaultConfig(APIKey)
|
||||
config.BaseURL = BaseURL
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: Model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: prompt,
|
||||
},
|
||||
},
|
||||
Temperature: 0.6, // 可调整
|
||||
TopP: 0.6, // 可调整
|
||||
MaxTokens: 4000, // 根据需要调整
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("API call failed: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices returned from API")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// AnalyzeByAI Gin 控制器方法
|
||||
func (tc *TrainingController) AnalyzeByAI(c *gin.Context) {
|
||||
// 1. 解析多部分表单请求
|
||||
form, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
log.Printf("Error parsing multipart form: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Failed to parse form: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 获取文件列表
|
||||
docxFiles := form.File["teaching_plan"] // 假设前端字段名为 'teaching_plan'
|
||||
csvFiles := form.File["heart_rate_data"] // 假设前端字段名为 'heart_rate_data'
|
||||
|
||||
if len(docxFiles) == 0 || len(csvFiles) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Missing required files: teaching_plan (.docx) or heart_rate_data (.csv)"})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 读取文件内容
|
||||
// 注意:这里我们只取第一个上传的文件
|
||||
teachingPlanFileHeader := docxFiles[0]
|
||||
heartRateFileHeader := csvFiles[0]
|
||||
|
||||
teachingPlanContent, err := readDocxContent(teachingPlanFileHeader)
|
||||
if err != nil {
|
||||
log.Printf("Error reading teaching plan file (%s): %v", teachingPlanFileHeader.Filename, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to process teaching plan file: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
heartRateContent, err := readCSVContent(heartRateFileHeader)
|
||||
if err != nil {
|
||||
log.Printf("Error reading heart rate file (%s): %v", heartRateFileHeader.Filename, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to process heart rate file: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 构建 Prompt
|
||||
prompt := buildAnalysisPrompt(teachingPlanContent, heartRateContent)
|
||||
|
||||
// 5. 调用 AI 分析
|
||||
analysisResult, err := callAIForAnalysis(prompt)
|
||||
if err != nil {
|
||||
log.Printf("Error calling AI for analysis: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("AI analysis failed: %v", err)})
|
||||
return
|
||||
}
|
||||
//outputFile := ".md"
|
||||
//ioutil.WriteFile(outputFile, []byte(analysisResult), 0644)
|
||||
|
||||
// 6. 返回结果
|
||||
// 方式一:返回 JSON 结构
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "success",
|
||||
"data": analysisResult,
|
||||
})
|
||||
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"hr_receiver/config"
|
||||
"hr_receiver/models"
|
||||
"hr_receiver/util"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" form:"username"`
|
||||
Password string `json:"password" form:"password"`
|
||||
}
|
||||
|
||||
type AuthResponse struct {
|
||||
Token string `json:"token"`
|
||||
User models.User `json:"user"`
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func Register(c *gin.Context) {
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户名是否已存在
|
||||
var existingUser models.User
|
||||
if result := config.DB.Where("username = ?", req.Username).First(&existingUser); result.Error == nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "Username already exists"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建新用户
|
||||
user := models.User{
|
||||
Username: req.Username,
|
||||
Password: req.Password, // BeforeCreate钩子会自动加密
|
||||
}
|
||||
|
||||
if result := config.DB.Create(&user); result.Error != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
token, err := util.GenerateToken(user.ID, user.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, AuthResponse{
|
||||
Token: token,
|
||||
User: user,
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
var user models.User
|
||||
result := config.DB.Where("username = ?", req.Username).First(&user)
|
||||
|
||||
if result.Error != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid username or password"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !user.CheckPassword(req.Password) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid username or password"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := util.GenerateToken(user.ID, user.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, AuthResponse{
|
||||
Token: token,
|
||||
User: user,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProfile 获取用户信息(需要认证)
|
||||
func GetProfile(c *gin.Context) {
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if result := config.DB.First(&user, userID); result.Error != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
@@ -0,0 +1,906 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sajari/regression"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"hr_receiver/config"
|
||||
"hr_receiver/models"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StepTrainingController struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewStepTrainingController() *StepTrainingController {
|
||||
return &StepTrainingController{DB: config.DB}
|
||||
}
|
||||
|
||||
// 接收训练记录
|
||||
func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) {
|
||||
var record models.StepTrainRecord
|
||||
|
||||
// 绑定并验证JSON数据
|
||||
if err := c.ShouldBindJSON(&record); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
username, exists := c.Get("username")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无法获取用户信息,请重新登录"})
|
||||
return
|
||||
}
|
||||
record.Username = username.(string)
|
||||
|
||||
// 使用事务保存数据[4](@ref)
|
||||
err := tc.DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 保存主记录
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "train_id"}}, // 指定冲突的列
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"max_heart_rate": record.MaxHeartRate,
|
||||
"start_time": record.StartTime,
|
||||
"end_time": record.EndTime,
|
||||
"duration": record.Duration,
|
||||
"dead_zone": record.DeadZone,
|
||||
"name": record.Name,
|
||||
"evaluation": record.Evaluation,
|
||||
}),
|
||||
}).Omit("HeartRates", "StrideFreqs").Create(&record).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存关联的心率数据
|
||||
for i := range record.HeartRates {
|
||||
if err := tx.Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{"heart_rate_type": record.HeartRates[i].HeartRateType, "value": record.HeartRates[i].Value, "time": record.HeartRates[i].Time}),
|
||||
},
|
||||
).Create(&record.HeartRates[i]).Error; err != nil {
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := range record.StrideFreqs {
|
||||
if err := tx.Clauses(
|
||||
clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "identifier"}}, // 指定冲突的列
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{"value": record.StrideFreqs[i].Value, "time": record.StrideFreqs[i].Time}),
|
||||
},
|
||||
).Create(&record.StrideFreqs[i]).Error; err != nil {
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// ====== 新增部分:启动异步回归计算 ======
|
||||
go func() {
|
||||
// 查询完整数据(需要关联的心率和步频数据)
|
||||
var fullRecord models.StepTrainRecord
|
||||
if err := tc.DB.
|
||||
Where("train_id = ?", record.TrainId).
|
||||
Preload("HeartRates", "heart_rate_type = ?", 1). // 只要有效心率
|
||||
Preload("StrideFreqs", "predict_value = ?", 1). // 只要有效步频
|
||||
First(&fullRecord).Error; err != nil {
|
||||
log.Printf("训练记录%d查询失败,无法计算回归: %v", record.TrainId, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查数据是否满足计算条件
|
||||
if len(fullRecord.HeartRates) == 0 || len(fullRecord.StrideFreqs) == 0 {
|
||||
log.Printf("训练记录%d缺少心率或步频数据,跳过回归计算", record.TrainId)
|
||||
return
|
||||
}
|
||||
|
||||
// 计算并保存回归结果
|
||||
if _, err := tc.GetOrCalculateRegression(fullRecord.TrainId); err != nil {
|
||||
log.Printf("训练记录%d回归计算失败: %v", fullRecord.TrainId, err)
|
||||
} else {
|
||||
log.Printf("训练记录%d回归结果已保存", fullRecord.TrainId)
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"message": "数据保存成功",
|
||||
"id": record.TrainId,
|
||||
})
|
||||
}
|
||||
|
||||
func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) {
|
||||
// 定义分页参数结构
|
||||
type PaginationParams struct {
|
||||
PageNum int `form:"pageNum,default=1"` // 页码,默认第一页
|
||||
PageSize int `form:"pageSize,default=10"` // 每页数量,默认10条
|
||||
}
|
||||
username, exists := c.Get("username")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无法获取用户信息,请重新登录"})
|
||||
return
|
||||
}
|
||||
|
||||
var params PaginationParams
|
||||
if err := c.ShouldBindQuery(¶ms); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证分页参数有效性
|
||||
if params.PageNum < 1 {
|
||||
params.PageNum = 1
|
||||
}
|
||||
if params.PageSize < 1 || params.PageSize > 100 {
|
||||
params.PageSize = 10
|
||||
}
|
||||
|
||||
// 计算偏移量
|
||||
offset := (params.PageNum - 1) * params.PageSize
|
||||
|
||||
var (
|
||||
records []models.StepTrainRecord
|
||||
totalRows int64
|
||||
)
|
||||
|
||||
// 获取总记录数
|
||||
if err := tc.DB.Model(&models.StepTrainRecord{}).Where("username = ?", username).Count(&totalRows).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 查询分页数据(按开始时间倒序排列)
|
||||
result := tc.DB.Where("username = ?", username).
|
||||
Order("start_time DESC"). // 按开始时间倒序
|
||||
Offset(offset).
|
||||
Limit(params.PageSize).
|
||||
Find(&records)
|
||||
|
||||
if result.Error != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总页数
|
||||
totalPages := int(math.Ceil(float64(totalRows) / float64(params.PageSize)))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "查询成功",
|
||||
"data": gin.H{
|
||||
"list": records,
|
||||
"pagination": gin.H{
|
||||
"currentPage": params.PageNum,
|
||||
"pageSize": params.PageSize,
|
||||
"totalPage": totalPages,
|
||||
"totalList": totalRows,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
func (tc *StepTrainingController) GetTrainingRecordByTrainId(c *gin.Context) {
|
||||
// 从URL路径参数获取trainId
|
||||
trainId := c.Param("trainId")
|
||||
if trainId == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "训练ID不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 将字符串trainId转换为uint类型
|
||||
tid, err := strconv.ParseInt(trainId, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID格式"})
|
||||
return
|
||||
}
|
||||
|
||||
var record models.StepTrainRecord
|
||||
|
||||
// 查询主记录并预加载关联的心率和步频数据
|
||||
result := tc.DB.Where("train_id = ?", uint(tid)).
|
||||
Preload("HeartRates").
|
||||
Preload("StrideFreqs").
|
||||
First(&record)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "训练记录不存在"})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": result.Error.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 成功返回数据
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "查询成功",
|
||||
"data": record,
|
||||
})
|
||||
}
|
||||
|
||||
// 定义结构体
|
||||
type SpeedSegment struct {
|
||||
Duration float64
|
||||
Speed float64
|
||||
}
|
||||
|
||||
// 实现线性回归算法
|
||||
func performLinearRegression(averages []map[float64]float64) models.RegressionResult {
|
||||
if len(averages) == 0 {
|
||||
return models.RegressionResult{
|
||||
Equation: "无数据",
|
||||
}
|
||||
}
|
||||
|
||||
// 收集数据点
|
||||
var points []struct{ x, y float64 }
|
||||
for _, m := range averages {
|
||||
for x, y := range m {
|
||||
points = append(points, struct{ x, y float64 }{x, y})
|
||||
}
|
||||
}
|
||||
|
||||
// 使用回归库计算
|
||||
r := new(regression.Regression)
|
||||
r.SetObserved("y")
|
||||
r.SetVar(0, "x")
|
||||
|
||||
for _, p := range points {
|
||||
r.Train(regression.DataPoint(p.y, []float64{p.x}))
|
||||
}
|
||||
|
||||
if err := r.Run(); err != nil {
|
||||
log.Printf("线性回归计算失败: %v", err)
|
||||
return models.RegressionResult{
|
||||
Equation: "计算失败",
|
||||
}
|
||||
}
|
||||
|
||||
// 创建结果
|
||||
slope := r.Coeff(1)
|
||||
intercept := r.Coeff(0)
|
||||
r2 := r.R2
|
||||
return models.RegressionResult{
|
||||
RegressionType: models.LinearRegression,
|
||||
Slope: &slope,
|
||||
Intercept: &intercept,
|
||||
RSquared: &r2,
|
||||
Equation: r.Formula,
|
||||
}
|
||||
}
|
||||
|
||||
// 实现对数和二次回归算法
|
||||
// 对数回归算法
|
||||
func performLogarithmicRegression(averages []map[float64]float64) models.RegressionResult {
|
||||
if len(averages) == 0 {
|
||||
return models.RegressionResult{
|
||||
Equation: "无数据",
|
||||
}
|
||||
}
|
||||
|
||||
// 收集数据点
|
||||
r := new(regression.Regression)
|
||||
r.SetObserved("y")
|
||||
r.SetVar(0, "log(x+1)")
|
||||
|
||||
for _, m := range averages {
|
||||
for speed, hr := range m {
|
||||
logSpeed := math.Log(speed + 1)
|
||||
r.Train(regression.DataPoint(hr, []float64{logSpeed}))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.Run(); err != nil {
|
||||
log.Printf("对数回归计算失败: %v", err)
|
||||
return models.RegressionResult{
|
||||
Equation: "计算失败",
|
||||
}
|
||||
}
|
||||
|
||||
// 创建结果
|
||||
logA := r.Coeff(1)
|
||||
logB := r.Coeff(0)
|
||||
r2 := r.R2
|
||||
return models.RegressionResult{
|
||||
RegressionType: models.LogarithmicRegression,
|
||||
LogA: &logA,
|
||||
LogB: &logB,
|
||||
RSquared: &r2,
|
||||
Equation: r.Formula,
|
||||
}
|
||||
}
|
||||
|
||||
// 二次回归算法
|
||||
//func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||||
// if len(averages) == 0 {
|
||||
// return models.RegressionResult{
|
||||
// Equation: "无数据",
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 收集数据点
|
||||
// r := new(regression.Regression)
|
||||
// r.SetObserved("y")
|
||||
// r.SetVar(0, "x")
|
||||
// r.SetVar(1, "x²")
|
||||
//
|
||||
// for _, m := range averages {
|
||||
// for speed, hr := range m {
|
||||
// speedSq := math.Pow(speed, 2)
|
||||
// r.Train(regression.DataPoint(hr, []float64{speed, speedSq}))
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if err := r.Run(); err != nil {
|
||||
// log.Printf("二次回归计算失败: %v", err)
|
||||
// return models.RegressionResult{
|
||||
// Equation: "计算失败",
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 创建结果
|
||||
// a := r.Coeff(2)
|
||||
// b := r.Coeff(1)
|
||||
// c := r.Coeff(0)
|
||||
// r2 := r.R2
|
||||
// return models.RegressionResult{
|
||||
// RegressionType: models.QuadraticRegression,
|
||||
// QuadraticA: &a,
|
||||
// QuadraticB: &b,
|
||||
// QuadraticC: &c,
|
||||
// RSquared: &r2,
|
||||
// Equation: r.Formula,
|
||||
// }
|
||||
//}
|
||||
|
||||
func performQuadraticRegression(averages []map[float64]float64) models.RegressionResult {
|
||||
if len(averages) == 0 {
|
||||
return models.RegressionResult{
|
||||
Equation: "无数据",
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤1:收集所有数据点(与Flutter一致)
|
||||
var xValues []float64
|
||||
var yValues []float64
|
||||
for _, m := range averages {
|
||||
for speed, hr := range m {
|
||||
xValues = append(xValues, speed)
|
||||
yValues = append(yValues, hr)
|
||||
}
|
||||
}
|
||||
n := float64(len(xValues))
|
||||
|
||||
// 步骤2:计算各项和(完全匹配Flutter的计算)
|
||||
var sumX, sumY, sumX2, sumX3, sumX4, sumXY, sumX2Y float64
|
||||
for i := 0; i < len(xValues); i++ {
|
||||
x := xValues[i]
|
||||
y := yValues[i]
|
||||
x2 := x * x
|
||||
x3 := x2 * x
|
||||
x4 := x3 * x
|
||||
|
||||
sumX += x
|
||||
sumY += y
|
||||
sumX2 += x2
|
||||
sumX3 += x3
|
||||
sumX4 += x4
|
||||
sumXY += x * y
|
||||
sumX2Y += x2 * y
|
||||
}
|
||||
|
||||
// 步骤3:构建正规方程矩阵(与Flutter完全一致)
|
||||
matrix := [3][3]float64{
|
||||
{n, sumX, sumX2},
|
||||
{sumX, sumX2, sumX3},
|
||||
{sumX2, sumX3, sumX4},
|
||||
}
|
||||
vector := []float64{sumY, sumXY, sumX2Y}
|
||||
|
||||
// 步骤4:计算矩阵行列式(复制Flutter的determinant3x3逻辑)
|
||||
det := matrix[0][0]*(matrix[1][1]*matrix[2][2]-matrix[1][2]*matrix[2][1]) -
|
||||
matrix[0][1]*(matrix[1][0]*matrix[2][2]-matrix[1][2]*matrix[2][0]) +
|
||||
matrix[0][2]*(matrix[1][0]*matrix[2][1]-matrix[1][1]*matrix[2][0])
|
||||
|
||||
if det == 0 {
|
||||
return models.RegressionResult{
|
||||
Equation: "无法拟合",
|
||||
}
|
||||
}
|
||||
|
||||
// 步骤5:克莱姆法则求解系数(顺序与Flutter一致)
|
||||
// 注意:最终系数顺序 a=二次项, b=一次项, c=常数项
|
||||
c := det3x3([3][3]float64{
|
||||
{vector[0], matrix[0][1], matrix[0][2]},
|
||||
{vector[1], matrix[1][1], matrix[1][2]},
|
||||
{vector[2], matrix[2][1], matrix[2][2]},
|
||||
}) / det
|
||||
|
||||
b := det3x3([3][3]float64{
|
||||
{matrix[0][0], vector[0], matrix[0][2]},
|
||||
{matrix[1][0], vector[1], matrix[1][2]},
|
||||
{matrix[2][0], vector[2], matrix[2][2]},
|
||||
}) / det
|
||||
|
||||
a := det3x3([3][3]float64{
|
||||
{matrix[0][0], matrix[0][1], vector[0]},
|
||||
{matrix[1][0], matrix[1][1], vector[1]},
|
||||
{matrix[2][0], matrix[2][1], vector[2]},
|
||||
}) / det
|
||||
|
||||
// 步骤6:计算R平方(完全复制Flutter的计算逻辑)
|
||||
var ssRes, ssTot float64
|
||||
meanY := sumY / n
|
||||
for i := 0; i < len(xValues); i++ {
|
||||
x := xValues[i]
|
||||
y := yValues[i]
|
||||
yPred := a*x*x + b*x + c
|
||||
ssRes += math.Pow(y-yPred, 2)
|
||||
ssTot += math.Pow(y-meanY, 2)
|
||||
}
|
||||
rSquared := 0.0
|
||||
if ssTot != 0 {
|
||||
rSquared = 1 - ssRes/ssTot
|
||||
}
|
||||
|
||||
// 步骤7:格式化公式字符串(与Flutter格式完全一致)
|
||||
equation := formatEquation(a, b, c, rSquared)
|
||||
|
||||
return models.RegressionResult{
|
||||
RegressionType: models.QuadraticRegression,
|
||||
QuadraticA: &a,
|
||||
QuadraticB: &b,
|
||||
QuadraticC: &c,
|
||||
RSquared: &rSquared,
|
||||
Equation: equation,
|
||||
}
|
||||
}
|
||||
|
||||
// 3x3行列式计算(与Flutter实现相同)
|
||||
func det3x3(m [3][3]float64) float64 {
|
||||
return m[0][0]*(m[1][1]*m[2][2]-m[1][2]*m[2][1]) -
|
||||
m[0][1]*(m[1][0]*m[2][2]-m[1][2]*m[2][0]) +
|
||||
m[0][2]*(m[1][0]*m[2][1]-m[1][1]*m[2][0])
|
||||
}
|
||||
|
||||
// 公式格式化(完全匹配Flutter格式)
|
||||
func formatEquation(a, b, c, r2 float64) string {
|
||||
// 保留4位小数
|
||||
aStr := fmt.Sprintf("%.4f", a)
|
||||
bStr := fmt.Sprintf("%.4f", b)
|
||||
cStr := fmt.Sprintf("%.4f", c)
|
||||
r2Str := fmt.Sprintf("%.4f", r2)
|
||||
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString("y = ")
|
||||
|
||||
// 处理二次项
|
||||
if a >= 0 {
|
||||
builder.WriteString(aStr + " x²")
|
||||
} else {
|
||||
builder.WriteString("-" + strings.TrimPrefix(aStr, "-") + " x²")
|
||||
}
|
||||
|
||||
// 处理一次项
|
||||
if b >= 0 {
|
||||
builder.WriteString(" + " + bStr + " x")
|
||||
} else {
|
||||
builder.WriteString(" - " + strings.TrimPrefix(bStr, "-") + " x")
|
||||
}
|
||||
|
||||
// 处理常数项
|
||||
if c >= 0 {
|
||||
builder.WriteString(" + " + cStr)
|
||||
} else {
|
||||
builder.WriteString(" - " + strings.TrimPrefix(cStr, "-"))
|
||||
}
|
||||
|
||||
builder.WriteString(" (R² = " + r2Str + ")")
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// 步频数据转换为速度段
|
||||
func convertStrideFrequencyToSegments(steps []models.StepStrideFreq) []SpeedSegment {
|
||||
if len(steps) == 0 {
|
||||
return []SpeedSegment{}
|
||||
}
|
||||
|
||||
// 过滤零值并排序
|
||||
validSteps := make([]models.StepStrideFreq, 0, len(steps))
|
||||
for _, s := range steps {
|
||||
if s.Value > 0 {
|
||||
validSteps = append(validSteps, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validSteps) == 0 {
|
||||
return []SpeedSegment{}
|
||||
}
|
||||
|
||||
// 按时间排序
|
||||
for i := 0; i < len(validSteps)-1; i++ {
|
||||
for j := i + 1; j < len(validSteps); j++ {
|
||||
if validSteps[i].Time > validSteps[j].Time {
|
||||
validSteps[i], validSteps[j] = validSteps[j], validSteps[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建速度段
|
||||
segments := make([]SpeedSegment, 0)
|
||||
startTime := validSteps[0].Time
|
||||
currentValue := validSteps[0].Value
|
||||
|
||||
for i := 1; i < len(validSteps); i++ {
|
||||
if validSteps[i].Value != currentValue {
|
||||
duration := float64(validSteps[i].Time-startTime) / 1000.0
|
||||
if duration > 0 {
|
||||
segments = append(segments, SpeedSegment{
|
||||
Duration: duration,
|
||||
Speed: float64(currentValue),
|
||||
})
|
||||
}
|
||||
startTime = validSteps[i].Time
|
||||
currentValue = validSteps[i].Value
|
||||
}
|
||||
}
|
||||
|
||||
// 添加最后一个段
|
||||
if len(validSteps) > 0 {
|
||||
duration := float64(validSteps[len(validSteps)-1].Time-startTime) / 1000.0
|
||||
if duration > 0 {
|
||||
segments = append(segments, SpeedSegment{
|
||||
Duration: duration,
|
||||
Speed: float64(currentValue),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// 计算区段平均值
|
||||
func calculateSegmentAverages(heartRates []models.StepHeartRate, segments []SpeedSegment, errorThreshold int) []map[float64]float64 {
|
||||
currentTime := 0.0
|
||||
results := make([]map[float64]float64, 0)
|
||||
|
||||
for _, seg := range segments {
|
||||
minRequired := 60 + (60 - float64(errorThreshold))
|
||||
|
||||
// 跳过不满足条件的区段
|
||||
if seg.Duration < minRequired {
|
||||
currentTime += seg.Duration
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算时间窗口
|
||||
startSec := currentTime + 60
|
||||
endSec := currentTime
|
||||
if seg.Duration >= 120 {
|
||||
endSec = currentTime + 120
|
||||
} else {
|
||||
endSec = currentTime + 120 - float64(errorThreshold)
|
||||
}
|
||||
|
||||
// 收集该区段的心率数据
|
||||
sum, count := 0, 0
|
||||
for _, hr := range heartRates {
|
||||
sec := float64(hr.Time) / 1000.0
|
||||
if sec >= startSec && sec <= endSec {
|
||||
sum += hr.Value
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
// 计算平均值
|
||||
if count > 0 {
|
||||
avg := float64(sum) / float64(count)
|
||||
results = append(results, map[float64]float64{seg.Speed: avg})
|
||||
}
|
||||
|
||||
currentTime += seg.Duration
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// 计算步频区段的心率平均值
|
||||
func CalculateSegmentAveragesByRealStep(heartRates []models.StepHeartRate, steps []models.StepStrideFreq) []map[float64]float64 {
|
||||
segments := convertStrideFrequencyToSegments(steps)
|
||||
return calculateSegmentAverages(heartRates, segments, 15) // 默认5秒误差阈值
|
||||
}
|
||||
|
||||
// 存储回归结果到数据库(支持多种回归类型)
|
||||
func (tc *StepTrainingController) SaveRegressionResults(trainId uint, results []models.RegressionResult) error {
|
||||
return tc.DB.Transaction(func(tx *gorm.DB) error {
|
||||
for i := range results {
|
||||
results[i].TrainId = trainId
|
||||
// 使用复合唯一约束确保每种回归类型只存储一条记录
|
||||
err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "id"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"equation": results[i].Equation,
|
||||
"slope": results[i].Slope,
|
||||
"intercept": results[i].Intercept,
|
||||
"log_a": results[i].LogA,
|
||||
"log_b": results[i].LogB,
|
||||
"quadratic_a": results[i].QuadraticA,
|
||||
"quadratic_b": results[i].QuadraticB,
|
||||
"quadratic_c": results[i].QuadraticC,
|
||||
"r_squared": results[i].RSquared,
|
||||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}),
|
||||
}).Create(&results[i]).Error
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// 获取或计算回归结果(返回多种回归类型列表)
|
||||
func (tc *StepTrainingController) GetOrCalculateRegression(trainId uint) ([]models.RegressionResult, error) {
|
||||
// 尝试从数据库获取所有类型的回归结果
|
||||
var results []models.RegressionResult
|
||||
err := tc.DB.Where("train_id = ?", trainId).Find(&results).Error
|
||||
|
||||
// 如果已存在三种类型的结果,直接返回
|
||||
if err == nil && len(results) >= 3 {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 查询训练记录及相关数据
|
||||
var record models.StepTrainRecord
|
||||
if err := tc.DB.
|
||||
Where("train_id = ?", uint(trainId)).
|
||||
Preload("HeartRates", "heart_rate_type = ?", 1).
|
||||
Preload("StrideFreqs", "predict_value = ?", 1).
|
||||
First(&record).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算心率平均值
|
||||
averages := CalculateSegmentAveragesByRealStep(record.HeartRates, record.StrideFreqs)
|
||||
if len(averages) == 0 {
|
||||
return nil, errors.New("无足够数据进行回归计算")
|
||||
}
|
||||
|
||||
// 创建三种回归类型的结果
|
||||
results = make([]models.RegressionResult, 3)
|
||||
|
||||
// 线性回归
|
||||
linearRes := performLinearRegression(averages)
|
||||
results[0] = models.RegressionResult{
|
||||
RegressionType: models.LinearRegression,
|
||||
TrainId: trainId,
|
||||
Equation: linearRes.Equation,
|
||||
Slope: linearRes.Slope,
|
||||
Intercept: linearRes.Intercept,
|
||||
RSquared: linearRes.RSquared,
|
||||
}
|
||||
|
||||
// 对数回归
|
||||
logRes := performLogarithmicRegression(averages)
|
||||
results[1] = models.RegressionResult{
|
||||
RegressionType: models.LogarithmicRegression,
|
||||
TrainId: trainId,
|
||||
Equation: logRes.Equation,
|
||||
LogA: logRes.LogA,
|
||||
LogB: logRes.LogB,
|
||||
RSquared: logRes.RSquared,
|
||||
}
|
||||
|
||||
// 二次回归
|
||||
quadRes := performQuadraticRegression(averages)
|
||||
results[2] = models.RegressionResult{
|
||||
RegressionType: models.QuadraticRegression,
|
||||
TrainId: trainId,
|
||||
Equation: quadRes.Equation,
|
||||
QuadraticA: quadRes.QuadraticA,
|
||||
QuadraticB: quadRes.QuadraticB,
|
||||
QuadraticC: quadRes.QuadraticC,
|
||||
RSquared: quadRes.RSquared,
|
||||
}
|
||||
|
||||
// 批量保存结果到数据库
|
||||
if err := tc.SaveRegressionResults(trainId, results); err != nil {
|
||||
log.Printf("保存回归结果失败: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 新增接口:获取回归结果
|
||||
func (tc *StepTrainingController) GetRegressionResult(c *gin.Context) {
|
||||
trainIdStr := c.Param("trainId")
|
||||
tid, err := strconv.ParseUint(trainIdStr, 10, 32)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := tc.GetOrCalculateRegression(uint(tid))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "获取成功",
|
||||
"data": result,
|
||||
})
|
||||
}
|
||||
|
||||
func (tc *StepTrainingController) GetTrainingRank(c *gin.Context) {
|
||||
// 参数解析
|
||||
trainIdStr := c.Param("trainId")
|
||||
regressionTypeStr := c.Query("type")
|
||||
regressionType, err := strconv.Atoi(regressionTypeStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "参数type必须为整数"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证回归类型
|
||||
regType := models.RegressionType(regressionType)
|
||||
if regType != models.LinearRegression && regType != models.QuadraticRegression {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的回归类型,必须是'linear'或'quadratic'"})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换训练ID
|
||||
tid, err := strconv.ParseUint(trainIdStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的训练ID"})
|
||||
return
|
||||
}
|
||||
trainId := uint(tid)
|
||||
|
||||
// 确保回归结果存在
|
||||
if _, err := tc.GetOrCalculateRegression(trainId); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取回归结果失败:" + err.Error()})
|
||||
return
|
||||
}
|
||||
// 获取指定训练的基准值
|
||||
var baseValue float64
|
||||
baseQuery := tc.DB.Model(&models.RegressionResult{}).
|
||||
Select(getValueColumn(regType)).
|
||||
Where("train_id = ?", trainId)
|
||||
|
||||
if err := baseQuery.Row().Scan(&baseValue); err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "指定的训练数据不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取基准数据失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 在数据库中进行排名计算
|
||||
var rank struct {
|
||||
BetterCount int64
|
||||
Total int64
|
||||
}
|
||||
|
||||
// 动态生成比较条件
|
||||
betterCondition := fmt.Sprintf("%s %s ?",
|
||||
getValueColumn(regType),
|
||||
getComparisonOperator(regType))
|
||||
|
||||
totalQuery := tc.DB.Model(&models.RegressionResult{}).
|
||||
Where(getTypeCondition(regType))
|
||||
|
||||
if err := totalQuery.Count(&rank.Total).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "统计总数失败"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.DB.Model(&models.RegressionResult{}).
|
||||
Where(getTypeCondition(regType)).
|
||||
Where(betterCondition, baseValue).
|
||||
Count(&rank.BetterCount).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "计算排名失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算实际排名 (并列排名)
|
||||
currentRank := rank.BetterCount + 1
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "排名查询成功",
|
||||
"data": gin.H{
|
||||
"trainId": trainId,
|
||||
"type": regressionType,
|
||||
"rank": currentRank,
|
||||
"total": rank.Total,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// 辅助函数:获取排序字段名
|
||||
func getValueColumn(regType models.RegressionType) string {
|
||||
switch regType {
|
||||
case models.LinearRegression:
|
||||
return "slope"
|
||||
case models.QuadraticRegression:
|
||||
return "ABS(quadratic_a)" // 计算绝对值
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:获取比较操作符
|
||||
func getComparisonOperator(regType models.RegressionType) string {
|
||||
switch regType {
|
||||
case models.LinearRegression:
|
||||
return "<" // 线性回归:值越小越好
|
||||
case models.QuadraticRegression:
|
||||
return ">" // 二次回归:绝对值越大越好
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:获取类型条件
|
||||
func getTypeCondition(regType models.RegressionType) string {
|
||||
switch regType {
|
||||
case models.LinearRegression:
|
||||
return "slope IS NOT NULL"
|
||||
case models.QuadraticRegression:
|
||||
return "quadratic_a IS NOT NULL AND quadratic_a < 0" // 确保是负值
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:比较浮点指针(用于线性回归)
|
||||
func compareFloatPtr(a, b *float64, ascending bool) bool {
|
||||
if a == nil && b == nil {
|
||||
return false
|
||||
}
|
||||
if a == nil {
|
||||
return false // 空值排最后
|
||||
}
|
||||
if b == nil {
|
||||
return true // 非空值排前
|
||||
}
|
||||
|
||||
if ascending {
|
||||
return *a < *b
|
||||
}
|
||||
return *a > *b
|
||||
}
|
||||
|
||||
// 辅助函数:比较二次项系数(用于二次回归)
|
||||
func compareQuadraticA(a, b *float64) bool {
|
||||
if a == nil && b == nil {
|
||||
return false
|
||||
}
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if b == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// 比较绝对值(a和b都是负值,所以取绝对值后大的排前面)
|
||||
return math.Abs(*a) > math.Abs(*b)
|
||||
}
|
||||
+44
-8
@@ -212,7 +212,12 @@ func convertToCurvePoints(x, y []float64) []CurvePoint {
|
||||
}
|
||||
|
||||
func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainRecord) error {
|
||||
startTime := record.StartTime
|
||||
var startTime int64
|
||||
if record.TestTime > 0 {
|
||||
startTime = record.TestTime
|
||||
} else {
|
||||
startTime = record.StartTime
|
||||
}
|
||||
|
||||
// 获取所有唯一的beltID
|
||||
var beltIDs []uint
|
||||
@@ -231,7 +236,7 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR
|
||||
// 曲线拟合
|
||||
x := []float64{2, 4, 6}
|
||||
y := []float64{averages["2min"], averages["4min"], averages["6min"]}
|
||||
a, _ := quadraticFit(x, y)
|
||||
a, b, _ := quadraticFit(x, y)
|
||||
|
||||
// 存储结果
|
||||
analysis := models.BeltAnalysis{
|
||||
@@ -242,6 +247,7 @@ func (tc *TrainingController) heartRateAnalyze(tx *gorm.DB, record models.TrainR
|
||||
Avg4min: averages["4min"],
|
||||
Avg6min: averages["6min"],
|
||||
CurveParamA: a,
|
||||
CurveParamB: b,
|
||||
}
|
||||
if err := tx.Create(&analysis).Error; err != nil {
|
||||
return err
|
||||
@@ -305,14 +311,44 @@ func calculateAverages(tx *gorm.DB, trainID uint, beltID uint, ranges map[string
|
||||
return averages, nil
|
||||
}
|
||||
|
||||
func quadraticFit(x []float64, y []float64) (float64, error) {
|
||||
// 使用三点计算y=ax²+b的a值(x=[2,4,6]对应分钟)
|
||||
//func quadraticFit(x []float64, y []float64) (float64, error) {
|
||||
// // 使用三点计算y=ax²+b的a值(x=[2,4,6]对应分钟)
|
||||
// if len(x) != 3 || len(y) != 3 {
|
||||
// return 0, errors.New("需要三个点")
|
||||
// }
|
||||
// // 构造方程组矩阵(简化计算)
|
||||
// a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
|
||||
// return a, nil
|
||||
//}
|
||||
|
||||
func quadraticFit(x []float64, y []float64) (float64, float64, error) {
|
||||
// 校验输入长度
|
||||
if len(x) != 3 || len(y) != 3 {
|
||||
return 0, errors.New("需要三个点")
|
||||
return 0, 0, errors.New("需要三个点")
|
||||
}
|
||||
// 构造方程组矩阵(简化计算)
|
||||
a := (y[2] - 2*y[1] + y[0]) / (x[2]*x[2] - 2*x[1]*x[1] + x[0]*x[0])
|
||||
return a, nil
|
||||
|
||||
// 计算各项累加值
|
||||
var sumX4, sumX2, sumY, sumX2Y float64
|
||||
for i := 0; i < 3; i++ {
|
||||
xi := x[i]
|
||||
xi2 := xi * xi
|
||||
sumX4 += xi2 * xi2 // x^4累加
|
||||
sumX2 += xi2 // x^2累加
|
||||
sumY += y[i] // y累加
|
||||
sumX2Y += xi2 * y[i] // x²y累加
|
||||
}
|
||||
|
||||
// 计算行列式
|
||||
determinant := sumX4*3 - sumX2*sumX2
|
||||
if determinant == 0 {
|
||||
return 0, 0, errors.New("无解,行列式为零")
|
||||
}
|
||||
|
||||
// 计算系数 a 和 b
|
||||
a := (sumX2Y*3 - sumY*sumX2) / determinant
|
||||
b := (sumX4*sumY - sumX2*sumX2Y) / determinant
|
||||
|
||||
return a, b, nil
|
||||
}
|
||||
|
||||
type TimeRange struct {
|
||||
|
||||
+7951
File diff suppressed because it is too large
Load Diff
+1
-1
@@ -4,7 +4,7 @@ services:
|
||||
app:
|
||||
build: .
|
||||
ports:
|
||||
- "8180:8080"
|
||||
- "127.0.0.1:8180:8080"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -3,8 +3,11 @@ module hr_receiver
|
||||
go 1.23.3
|
||||
|
||||
require (
|
||||
github.com/fumiama/go-docx v0.0.0-20250506085032-0c30fd09304b
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/sajari/regression v1.0.1
|
||||
github.com/sashabaranov/go-openai v1.41.2
|
||||
github.com/spf13/viper v1.20.0
|
||||
gonum.org/v1/gonum v0.16.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
@@ -17,6 +20,7 @@ require (
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/fumiama/imgsz v0.0.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
|
||||
@@ -13,6 +13,10 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fumiama/go-docx v0.0.0-20250506085032-0c30fd09304b h1:/mxSugRc4SgN7XgBtT19dAJ7cAXLTbPmlJLJE4JjRkE=
|
||||
github.com/fumiama/go-docx v0.0.0-20250506085032-0c30fd09304b/go.mod h1:ssRF0IaB1hCcKIObp3FkZOsjTcAHpgii70JelNb4H8M=
|
||||
github.com/fumiama/imgsz v0.0.2 h1:fAkC0FnIscdKOXwAxlyw3EUba5NzxZdSxGaq3Uyfxak=
|
||||
github.com/fumiama/imgsz v0.0.2/go.mod h1:dR71mI3I2O5u6+PCpd47M9TZptzP+39tRBcbdIkoqM4=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
@@ -75,6 +79,10 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
|
||||
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
|
||||
github.com/sajari/regression v1.0.1 h1:iTVc6ZACGCkoXC+8NdqH5tIreslDTT/bXxT6OmHR5PE=
|
||||
github.com/sajari/regression v1.0.1/go.mod h1:NeG/XTW1lYfGY7YV/Z0nYDV/RGh3wxwd1yW46835flM=
|
||||
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
|
||||
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
|
||||
|
||||
@@ -15,7 +15,17 @@ func main() {
|
||||
|
||||
config.DB.Debug()
|
||||
// 自动迁移模型
|
||||
config.DB.AutoMigrate(&models.TrainRecord{}, &models.TrainingData{}, &models.Belt{}, &models.HeartRate{}, &models.BeltAnalysis{})
|
||||
config.DB.AutoMigrate(&models.TrainRecord{},
|
||||
&models.TrainingData{},
|
||||
&models.Belt{},
|
||||
&models.HeartRate{},
|
||||
&models.BeltAnalysis{},
|
||||
&models.StepTrainRecord{},
|
||||
&models.StepHeartRate{},
|
||||
&models.StepStrideFreq{},
|
||||
&models.RegressionResult{},
|
||||
&models.User{},
|
||||
)
|
||||
|
||||
// 启动服务
|
||||
r := routes.SetupRouter()
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"compress/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GzipMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 1. 处理请求解压
|
||||
if c.Request.Header.Get("Content-Encoding") == "gzip" {
|
||||
gzReader, err := gzip.NewReader(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -17,6 +19,39 @@ func GzipMiddleware() gin.HandlerFunc {
|
||||
defer gzReader.Close()
|
||||
c.Request.Body = gzReader
|
||||
}
|
||||
|
||||
// 2. 设置响应压缩支持
|
||||
if strings.Contains(c.Request.Header.Get("Accept-Encoding"), "gzip") {
|
||||
// 创建gzip writer
|
||||
gzWriter := gzip.NewWriter(c.Writer)
|
||||
defer gzWriter.Close()
|
||||
|
||||
// 替换原始writer为压缩writer
|
||||
originalWriter := c.Writer
|
||||
c.Writer = &gzipResponseWriter{
|
||||
ResponseWriter: originalWriter,
|
||||
gzWriter: gzWriter,
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Encoding", "gzip")
|
||||
c.Header("Vary", "Accept-Encoding")
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义ResponseWriter实现gzip压缩
|
||||
type gzipResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
gzWriter *gzip.Writer
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) Write(data []byte) (int, error) {
|
||||
return w.gzWriter.Write(data)
|
||||
}
|
||||
|
||||
func (w *gzipResponseWriter) WriteString(s string) (int, error) {
|
||||
return w.gzWriter.Write([]byte(s))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"hr_receiver/util"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func JWTAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Bearer Token格式
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if !(len(parts) == 2 && parts[0] == "Bearer") {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析Token
|
||||
claims, err := util.ParseToken(parts[1])
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存入上下文
|
||||
c.Set("userID", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
package models
|
||||
@@ -0,0 +1,70 @@
|
||||
package models
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type StepStrideFreq struct {
|
||||
gorm.Model
|
||||
TrainId uint `gorm:"column:train_id; index" json:"trainId"` // 外键关联训练记录[4](@ref)
|
||||
Time int64 `gorm:"type:bigint" json:"time"` // 保持与前端一致的毫秒时间戳[3](@ref)
|
||||
Value int `gorm:"type:int" json:"value"`
|
||||
Count int `gorm:"type:int" json:"count"`
|
||||
PredictValue int `gorm:"type:int" json:"predictValue"`
|
||||
Identifier string `gorm:"uniqueIndex;type:varchar(255)" json:"identifier"`
|
||||
}
|
||||
|
||||
// 对应Flutter的HeartRate结构
|
||||
type StepHeartRate struct {
|
||||
gorm.Model
|
||||
TrainId uint `gorm:"column:train_id; index" json:"trainId"` // 外键关联训练记录[4](@ref)
|
||||
Time int64 `gorm:"type:bigint" json:"time"` // 保持与前端一致的毫秒时间戳[3](@ref)
|
||||
Value int `gorm:"type:int" json:"value"`
|
||||
HeartRateType int `gorm:"type:int" json:"predictValue"`
|
||||
Identifier string `gorm:"uniqueIndex;type:varchar(255)" json:"identifier"`
|
||||
}
|
||||
|
||||
// 对应Flutter的TrainRecord结构
|
||||
type StepTrainRecord struct {
|
||||
gorm.Model
|
||||
Username string `gorm:"size:50" json:"username"` // 对应Dart的tid字段
|
||||
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
|
||||
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
|
||||
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
|
||||
Name string `gorm:"size:100" json:"name"`
|
||||
RunType string `gorm:"size:100" json:"runType"`
|
||||
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
|
||||
Duration int `gorm:"type:int" json:"duration"` // 持续时间(秒)
|
||||
DeadZone int `gorm:"type:int" json:"deadZone"`
|
||||
Evaluation string `gorm:"size:50" json:"evaluation"`
|
||||
HeartRates []StepHeartRate `gorm:"foreignKey:TrainId;references:TrainId" json:"heartRates"`
|
||||
StrideFreqs []StepStrideFreq `gorm:"foreignKey:TrainId;references:TrainId" json:"strideFreqs"`
|
||||
}
|
||||
|
||||
type RegressionType int
|
||||
|
||||
const (
|
||||
LinearRegression RegressionType = iota + 1
|
||||
LogarithmicRegression
|
||||
QuadraticRegression
|
||||
)
|
||||
|
||||
type RegressionResult struct {
|
||||
gorm.Model
|
||||
RegressionType RegressionType `gorm:"column:regression_type;index" json:"regressionType"` // 训练记录ID
|
||||
TrainId uint `gorm:"column:train_id;index" json:"trainId"` // 训练记录ID
|
||||
Equation string `gorm:"type:text" json:"equation"` // 回归方程
|
||||
|
||||
// 线性回归系数
|
||||
Slope *float64 `gorm:"column:slope" json:"slope"`
|
||||
Intercept *float64 `gorm:"column:intercept" json:"intercept"`
|
||||
|
||||
// 对数回归系数
|
||||
LogA *float64 `gorm:"column:log_a" json:"logA"`
|
||||
LogB *float64 `gorm:"column:log_b" json:"logB"`
|
||||
|
||||
// 二次回归系数
|
||||
QuadraticA *float64 `gorm:"column:quadratic_a" json:"quadraticA"`
|
||||
QuadraticB *float64 `gorm:"column:quadratic_b" json:"quadraticB"`
|
||||
QuadraticC *float64 `gorm:"column:quadratic_c" json:"quadraticC"`
|
||||
|
||||
RSquared *float64 `gorm:"column:r_squared" json:"rSquared"` // R平方值
|
||||
}
|
||||
+5
-3
@@ -28,6 +28,7 @@ type BeltAnalysis struct {
|
||||
Avg4min float64 `gorm:"type:double precision"` // 第4分钟平均心率
|
||||
Avg6min float64 `gorm:"type:double precision"` // 第6分钟平均心率
|
||||
CurveParamA float64 `gorm:"type:double precision"` // 拟合参数a值
|
||||
CurveParamB float64 `gorm:"type:double precision"` // 拟合参数a值
|
||||
}
|
||||
|
||||
// 中间计算结构(无需持久化)
|
||||
@@ -47,9 +48,10 @@ type HeartRate struct {
|
||||
// 对应Flutter的TrainRecord结构
|
||||
type TrainRecord struct {
|
||||
gorm.Model
|
||||
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
|
||||
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
|
||||
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
|
||||
TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段
|
||||
StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳
|
||||
TestTime int64 `gorm:"type:bigint" json:"testTime"` // 开始时间戳
|
||||
EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref)
|
||||
Name string `gorm:"size:100" json:"name"`
|
||||
RunType string `gorm:"size:100" json:"RunType"`
|
||||
MaxHeartRate int `gorm:"type:int" json:"maxHeartRate"`
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Username string `gorm:"uniqueIndex;not null" json:"username"`
|
||||
Email *string `gorm:"uniqueIndex;" json:"email"`
|
||||
Phone *string `gorm:"uniqueIndex;" json:"phone"`
|
||||
Password string `gorm:"not null" json:"-"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// HashPassword 密码加密
|
||||
func (u *User) HashPassword(password string) error {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Password = string(bytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码
|
||||
func (u *User) CheckPassword(password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前钩子
|
||||
func (u *User) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if u.Password != "" {
|
||||
return u.HashPassword(u.Password)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -12,6 +12,7 @@ func SetupRouter() *gin.Engine {
|
||||
r := gin.Default()
|
||||
r.Use(middleware.GzipMiddleware())
|
||||
trainingController := controllers.NewTrainingController()
|
||||
stepTrainController := controllers.NewStepTrainingController()
|
||||
|
||||
v1 := r.Group("/api/v1")
|
||||
{
|
||||
@@ -19,8 +20,22 @@ func SetupRouter() *gin.Engine {
|
||||
{
|
||||
records.POST("", trainingController.CreateTrainingRecord)
|
||||
records.GET("/analysis", trainingController.HandleCurveAnalysis)
|
||||
records.POST("/analysis-by-ai", trainingController.AnalyzeByAI)
|
||||
// 可扩展其他路由:GET, PUT, DELETE等
|
||||
}
|
||||
steps := v1.Group("/step").Use(middleware.JWTAuth())
|
||||
{
|
||||
steps.POST("", stepTrainController.CreateTrainingRecord)
|
||||
steps.GET("train-records", stepTrainController.GetTrainingRecords)
|
||||
steps.GET("train-data/:trainId", stepTrainController.GetTrainingRecordByTrainId)
|
||||
steps.GET("train-rank/:trainId", stepTrainController.GetTrainingRank)
|
||||
// 可扩展其他路由:GET, PUT, DELETE等
|
||||
}
|
||||
public := v1.Group("")
|
||||
{
|
||||
public.POST("/register", controllers.Register)
|
||||
public.POST("/login", controllers.Login)
|
||||
}
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.GET("/token", func(c *gin.Context) {
|
||||
|
||||
+188
@@ -0,0 +1,188 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// 配置文件
|
||||
const (
|
||||
BaseURL = "https://api.lkeap.cloud.tencent.com/v1"
|
||||
APIKey = "sk-Y4zjnwulSuSlf60mrzwCxq2ipktHSs4jZHgWeQOArWuWJEOd" // 替换为实际的API Key
|
||||
Model = "deepseek-v3" // 推荐使用terminus版本
|
||||
)
|
||||
|
||||
// 读取文件内容
|
||||
func readFileContent(filename string) (string, error) {
|
||||
content, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取文件 %s 失败: %v", filename, err)
|
||||
}
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
// 构建分析提示词
|
||||
func buildAnalysisPrompt(teachingPlanContent, heartRateContent string) string {
|
||||
return fmt.Sprintf(`请根据以下体育课堂的教案和心率监测数据,生成一份详细的课堂分析报告:
|
||||
|
||||
## 教案内容:
|
||||
%s
|
||||
|
||||
## 心率监测数据:
|
||||
%s
|
||||
|
||||
这是一份幼儿园体育课的教案和课程心率监测数据,请帮对照分析课程教学效果,运动量和运动负荷情况是否科学,并提出课程设计的优化方案。
|
||||
优化方案参考如下格式,教学过程需要详细一些:
|
||||
# 幼儿体育教案(华侨大学版本)
|
||||
|
||||
| 项目 | 内容 |
|
||||
| ------------ | -------------------------------- |
|
||||
| **课程名** | |
|
||||
| **年段** | 小 中 大 |
|
||||
| **教师姓名** | |
|
||||
| **时间** | 年 月 日 |
|
||||
| **地点** | |
|
||||
| **人数** | 男: 女: |
|
||||
| **时长** | 分钟 |
|
||||
| **天气预报** | 晴 雨 阴;温度 ℃ |
|
||||
| **器材准备** | |
|
||||
|
||||
## 教学目标
|
||||
|
||||
| 类型 | 目标 |
|
||||
| -------- | ------------ |
|
||||
| **体能目标** | |
|
||||
| **技能目标** | |
|
||||
| **情感目标** | |
|
||||
|
||||
## 教学过程
|
||||
|
||||
| 阶段 | 阶段 | 项目名称 | 引导语及教学方法 | 队形/站位/留意点 | 目标心率区间 | 时间(分) |
|
||||
| ---------- | -------- | ----------------------------- | ------------------------ | --------------------- | ------------ | ---------- |
|
||||
| **准备部分** | 热身 | | | | | 3 |
|
||||
| | 注意力游戏 | | | | | 3 |
|
||||
| **正课部分** | 基本素质练习及常规意识培养环节 | | | | | 5 |
|
||||
| | 复习环节 | | | | | 5 |
|
||||
| | 新授环节 | | | | | 8 |
|
||||
| **结束部分** | 社会性及情感目标游戏 | | | | | 4 |
|
||||
| | 整理放松 | | | | | 2 |
|
||||
|
||||
|
||||
请以专业体育教师的视角,提供详细的数据分析和教学建议。`, teachingPlanContent, heartRateContent)
|
||||
}
|
||||
|
||||
// 调用大模型进行分析
|
||||
func analyzeClassData(teachingPlanFile, heartRateFile string) (string, error) {
|
||||
// 读取文件内容
|
||||
teachingPlanContent, err := readFileContent(teachingPlanFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
heartRateContent, err := readFileContent(heartRateFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 构建客户端
|
||||
config := openai.DefaultConfig(APIKey)
|
||||
config.BaseURL = BaseURL
|
||||
client := openai.NewClientWithConfig(config)
|
||||
|
||||
// 构建提示词
|
||||
prompt := buildAnalysisPrompt(teachingPlanContent, heartRateContent)
|
||||
|
||||
// 调用API
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: Model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: prompt,
|
||||
},
|
||||
},
|
||||
Temperature: 0.6, // 使用默认值
|
||||
TopP: 0.6, // 使用默认值
|
||||
MaxTokens: 5000, // 适当限制输出长度
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("API调用失败: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("未收到有效响应")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// 保存分析结果到文件
|
||||
func saveAnalysisResult(result, outputFile string) error {
|
||||
// 添加时间戳和分隔符
|
||||
timestamp := "生成时间: " + getCurrentTime()
|
||||
separator := strings.Repeat("=", 80)
|
||||
|
||||
formattedResult := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n",
|
||||
separator, timestamp, separator, result, separator)
|
||||
|
||||
return ioutil.WriteFile(outputFile, []byte(formattedResult), 0644)
|
||||
}
|
||||
|
||||
// 获取当前时间(简化版)
|
||||
func getCurrentTime() string {
|
||||
// 实际使用时可以导入time包
|
||||
return time.Now().Format("2006-01-02 15:04:05") // 替换为 time.Now().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 文件路径配置
|
||||
teachingPlanFile := "D:\\projects\\IdeaProjects\\hr_receiver\\test\\b.md"
|
||||
heartRateFile := "D:\\projects\\IdeaProjects\\hr_receiver\\test\\b.csv"
|
||||
outputFile := "小班.md"
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(teachingPlanFile); os.IsNotExist(err) {
|
||||
log.Fatalf("教案文件不存在: %s", teachingPlanFile)
|
||||
}
|
||||
if _, err := os.Stat(heartRateFile); os.IsNotExist(err) {
|
||||
log.Fatalf("心率数据文件不存在: %s", heartRateFile)
|
||||
}
|
||||
|
||||
fmt.Println("开始分析体育课堂数据...")
|
||||
fmt.Printf("教案文件: %s\n", teachingPlanFile)
|
||||
fmt.Printf("心率数据: %s\n", heartRateFile)
|
||||
|
||||
// 进行分析
|
||||
result, err := analyzeClassData(teachingPlanFile, heartRateFile)
|
||||
if err != nil {
|
||||
log.Fatalf("分析失败: %v", err)
|
||||
}
|
||||
|
||||
// 保存结果
|
||||
err = saveAnalysisResult(result, outputFile)
|
||||
if err != nil {
|
||||
log.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("分析完成!结果已保存到: %s\n", outputFile)
|
||||
fmt.Println("\n分析报告摘要:")
|
||||
fmt.Println(strings.Repeat("-", 50))
|
||||
|
||||
// 显示前200字符作为预览
|
||||
preview := result
|
||||
if len(result) > 200 {
|
||||
preview = result[:200] + "..."
|
||||
}
|
||||
fmt.Println(preview)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
docx "github.com/fumiama/go-docx"
|
||||
)
|
||||
|
||||
func DocxToStructuredPrompt(filename string) (string, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
doc, err := docx.Parse(f, fi.Size())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("# 文件:%s\n\n", filename))
|
||||
|
||||
for _, item := range doc.Document.Body.Items {
|
||||
switch v := item.(type) {
|
||||
case *docx.Paragraph:
|
||||
// 直接用 fmt.Sprint 利用庫的 Stringer
|
||||
text := fmt.Sprint(v)
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" {
|
||||
sb.WriteString(text + "\n\n")
|
||||
}
|
||||
|
||||
case *docx.Table:
|
||||
sb.WriteString("## 表格\n")
|
||||
|
||||
// 先印表頭(可選)
|
||||
sb.WriteString("| ")
|
||||
|
||||
// 假設第一行是表頭(很多文件如此),或全部當內容
|
||||
for i, row := range v.TableRows {
|
||||
var cells []string
|
||||
for _, cell := range row.TableCells {
|
||||
// 這裡是重點:cell 本身沒有 String(),但可以遍歷它的 Paragraphs
|
||||
var cellText strings.Builder
|
||||
for _, p := range cell.Paragraphs {
|
||||
cellText.WriteString(fmt.Sprint(p))
|
||||
cellText.WriteString(" ")
|
||||
}
|
||||
cells = append(cells, strings.TrimSpace(cellText.String()))
|
||||
}
|
||||
|
||||
sb.WriteString(strings.Join(cells, " | "))
|
||||
sb.WriteString(" |\n")
|
||||
|
||||
// 如果想加 markdown 表頭分隔線(只在第一行後加)
|
||||
if i == 0 {
|
||||
sb.WriteString("| " + strings.Repeat("--- | ", len(cells)) + "\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
default:
|
||||
// 忽略圖片、頁首等
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func main1() {
|
||||
// 測試用
|
||||
prompt, err := docxToStructuredPrompt("D:\\myDocument\\tencent\\weChat\\WeChat Files\\wxid_pv6rg3z2l28y22\\FileStorage\\File\\2026-01\\(改)小班体育活动《蚂蚁运粮》(泉秀实幼吴思莹).docx")
|
||||
if err != nil {
|
||||
fmt.Println("錯誤:", err)
|
||||
return
|
||||
}
|
||||
fmt.Println(prompt)
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
docx "github.com/fumiama/go-docx"
|
||||
)
|
||||
|
||||
func DocxToStructuredPrompt(filename string) (string, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
doc, err := docx.Parse(f, fi.Size())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("# 文件:%s\n\n", filename))
|
||||
|
||||
for _, item := range doc.Document.Body.Items {
|
||||
switch v := item.(type) {
|
||||
case *docx.Paragraph:
|
||||
// 直接用 fmt.Sprint 利用庫的 Stringer
|
||||
text := fmt.Sprint(v)
|
||||
text = strings.TrimSpace(text)
|
||||
if text != "" {
|
||||
sb.WriteString(text + "\n\n")
|
||||
}
|
||||
|
||||
case *docx.Table:
|
||||
sb.WriteString("## 表格\n")
|
||||
|
||||
// 先印表頭(可選)
|
||||
sb.WriteString("| ")
|
||||
|
||||
// 假設第一行是表頭(很多文件如此),或全部當內容
|
||||
for i, row := range v.TableRows {
|
||||
var cells []string
|
||||
for _, cell := range row.TableCells {
|
||||
// 這裡是重點:cell 本身沒有 String(),但可以遍歷它的 Paragraphs
|
||||
var cellText strings.Builder
|
||||
for _, p := range cell.Paragraphs {
|
||||
cellText.WriteString(fmt.Sprint(p))
|
||||
cellText.WriteString(" ")
|
||||
}
|
||||
cells = append(cells, strings.TrimSpace(cellText.String()))
|
||||
}
|
||||
|
||||
sb.WriteString(strings.Join(cells, " | "))
|
||||
sb.WriteString(" |\n")
|
||||
|
||||
// 如果想加 markdown 表頭分隔線(只在第一行後加)
|
||||
if i == 0 {
|
||||
sb.WriteString("| " + strings.Repeat("--- | ", len(cells)) + "\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
default:
|
||||
// 忽略圖片、頁首等
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
+56
@@ -0,0 +1,56 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var ApiSecret = "your-super-secret-key" // 预共享密钥
|
||||
type Claims struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT Token
|
||||
func GenerateToken(userID uint, username string) (string, error) {
|
||||
expirationTime := time.Now().Add(24 * 30 * time.Hour) // Token有效期24小时
|
||||
//expirationTime := time.Now().Add(1 * time.Second) // Token有效期24小时
|
||||
|
||||
claims := &Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expirationTime),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "your-app-name",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(ApiSecret))
|
||||
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
// ParseToken 解析JWT Token
|
||||
func ParseToken(tokenStr string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(ApiSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
Reference in New Issue
Block a user