// main.go package main import ( "errors" "fmt" _ "fmt" "github.com/gin-gonic/gin" "github.com/lib/pq" "gorm.io/driver/postgres" "gorm.io/gorm/clause" "log" "net/http" "net/url" "regexp" "sort" "time" "gorm.io/gorm" ) // 调整后的结构体定义 type CreateTrainingRequest struct { TID uint `json:"tid" binding:"required"` StartTime int64 `json:"startTime" binding:"required,min=1609459200000"` // 2021-01-01起 EndTime int64 `json:"endTime" binding:"required,min=1609459200000"` Name string `json:"name" binding:"required"` MaxHeartRate int `json:"maxHeartRate" binding:"required,min=30,max=250"` Duration int `json:"duration" binding:"required"` // 单位:秒 PeopleNum int `json:"peopleNum" binding:"required"` Evaluation string `json:"evaluation" binding:"required"` BeltAddrs []string `json:"beltAddrs" binding:"required"` } type HeartRateUploadRequest struct { TrainID uint `json:"trainId" binding:"required"` Data []HeartRateData `json:"data" binding:"required,dive"` } type HeartRateData struct { BeltAddr string `json:"beltAddr" binding:"required"` Timestamp int64 `json:"timestamp" binding:"required,min=1609459200000"` Value int `json:"value" binding:"required,min=30,max=250"` LastValue int `json:"lastValue" binding:"required,min=30,max=250"` } // 数据库模型 type TrainingRecord struct { TID uint `gorm:"primaryKey;column:tid" json:"tid"` StartTime time.Time `gorm:"not null" json:"-"` EndTime time.Time `gorm:"not null" json:"-"` Name string `gorm:"type:varchar(255);default:'训练'" json:"name"` MaxHeartRate int `gorm:"not null" json:"maxHeartRate"` Duration int `gorm:"not null" json:"duration"` PeopleNum int `gorm:"not null" json:"peopleNum"` Evaluation string `gorm:"type:varchar(255);default:'适中'" json:"evaluation"` BeltAddresses []string `gorm:"type:text[]" json:"beltAddrs"` // 添加毫秒时间戳字段(仅用于JSON序列化) StartTimestamp int64 `gorm:"-" json:"startTime"` EndTimestamp int64 `gorm:"-" json:"endTime"` } type HeartRate struct { Time time.Time `gorm:"primaryKey;type:timestamptz" json:"time"` TrainID uint `gorm:"primaryKey;index" json:"train_id"` BeltAddr string `gorm:"type:varchar(255);not null;index" json:"belt_addr"` Value int `gorm:"check:value BETWEEN 30 AND 250" json:"value"` LastValue int `gorm:"check:last_value BETWEEN 30 AND 250" json:"last_value"` } // 统计分析响应结构 type TrainingReport struct { TrainID uint `json:"train_id"` AvgHeartRate float64 `json:"avg_heart_rate"` MaxHeartRate int `json:"max_heart_rate"` DangerSeconds int `json:"danger_seconds"` BeltStats map[string]Stats `json:"belt_stats"` TimeSeries []TimePoint `json:"time_series,omitempty"` } type Stats struct { Avg float64 `json:"avg"` Max int `json:"max"` DangerCount int `json:"danger_count"` } type TimePoint struct { Time time.Time `json:"time"` AvgValue float64 `json:"avg_value"` DangerCount int `json:"danger_count"` } // 实现自定义序列化逻辑 func (t *TrainingRecord) AfterFind(tx *gorm.DB) (err error) { t.StartTimestamp = t.StartTime.UnixNano() / int64(time.Millisecond) t.EndTimestamp = t.EndTime.UnixNano() / int64(time.Millisecond) return } func (t *TrainingRecord) BeforeCreate(tx *gorm.DB) (err error) { t.StartTime = time.Unix(0, t.StartTimestamp*int64(time.Millisecond)).UTC() t.EndTime = time.Unix(0, t.EndTimestamp*int64(time.Millisecond)).UTC() return } // 更新处理函数 //func createTraining(c *gin.Context) { // var req CreateTrainingRequest // if err := c.ShouldBindJSON(&req); err != nil { // c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) // return // } // // record := TrainingRecord{ // TID: req.TID, // StartTimestamp: req.StartTime, // EndTimestamp: req.EndTime, // Name: req.Name, // MaxHeartRate: req.MaxHeartRate, // Duration: req.Duration, // PeopleNum: req.PeopleNum, // Evaluation: req.Evaluation, // BeltAddresses: req.BeltAddrs, // } // // if err := db.Create(&record).Error; err != nil { // if isDuplicateKeyError(err) { // c.JSON(http.StatusConflict, gin.H{"error": "训练记录已存在"}) // return // } // c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) // return // } // // c.JSON(http.StatusCreated, record) //} // 统一时间处理函数 func parseTimestamp(ms int64) (time.Time, error) { if ms < 1609459200000 { // 2021-01-01 00:00:00 UTC return time.Time{}, fmt.Errorf("无效的时间戳") } return time.Unix(0, ms*int64(time.Millisecond)).UTC(), nil } // 更新心率数据处理 func processHeartRateData(req *HeartRateUploadRequest) ([]HeartRate, error) { var rates []HeartRate for _, d := range req.Data { t, err := parseTimestamp(d.Timestamp) if err != nil { return nil, fmt.Errorf("无效的时间戳: %d", d.Timestamp) } rates = append(rates, HeartRate{ TrainID: req.TrainID, BeltAddr: d.BeltAddr, Time: t, Value: d.Value, LastValue: d.LastValue, }) } // 按时间排序 sort.Slice(rates, func(i, j int) bool { return rates[i].Time.Before(rates[j].Time) }) return rates, nil } // 更新自动迁移逻辑 func autoMigrate(db *gorm.DB) { db.Set("gorm:table_options", " comment '训练记录表'").AutoMigrate(&TrainingRecord{}) db.Set("gorm:table_options", " comment '心率数据表'").AutoMigrate(&HeartRate{}) // 创建复合索引 db.Exec(` CREATE INDEX IF NOT EXISTS idx_heart_rates_main ON heart_rates (train_id, belt_addr, time DESC) `) } var db *gorm.DB const ( defaultDB = "postgres" // 用于创建新数据库的默认数据库 ) func initDB() { // 解析原始DSN dsn := "host=localhost user=postgres password=root dbname=training port=5432 sslmode=disable" parsedDSN, err := url.Parse(dsn) if err != nil { log.Fatal("Invalid DSN:", err) } query := parsedDSN.Query() query.Set("dbname", defaultDB) query.Set("sslmode", "disable") query.Set("port", "5432") query.Set("host", "localhost") query.Set("user", "postgres") query.Set("password", "root") defaultDSN := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s", query.Get("host"), query.Get("user"), query.Get("password"), query.Get("dbname"), query.Get("port"), query.Get("sslmode"), ) //defaultDSNUrl, _ := url.Parse(defaultDSN) // 提取数据库名称 dbName := "training" // 第一步:尝试连接目标数据库 db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { // 检查是否是数据库不存在的错误(PostgreSQL错误码3D000) if isDatabaseNotExistError(err) { log.Printf("Database %q does not exist, attempting to create...", dbName) createDatabase(defaultDSN, dbName) } else { log.Fatal("Database connection failed:", err) } } else { log.Printf("Database %q already exists", dbName) sqlDB, _ := db.DB() sqlDB.Close() } // 再次连接目标数据库 db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ CreateBatchSize: 1000, }) if err != nil { log.Fatal("Final database connection failed:", err) } // 自动迁移表结构 autoMigrate(db) } // 检查是否为数据库不存在错误 func isDatabaseNotExistError(err error) bool { if err == nil { return false } errStr := err.Error() return containsErrorCode(errStr, "3D000") } func containsErrorCode(errStr, code string) bool { // 使用正则表达式检查错误字符串中是否包含指定的错误码 re := regexp.MustCompile(`\b` + code + `\b`) return re.MatchString(errStr) } // 创建新数据库 func createDatabase(parsedDSN string, dbName string) { // 使用默认数据库连接 //parsedDSN.Path = "/" + defaultDB defaultDSN := parsedDSN // 创建数据库 db, err := gorm.Open(postgres.Open(defaultDSN), &gorm.Config{}) if err != nil { log.Fatal("Connect to default database failed:", err) } // 需要超级用户权限才能创建数据库 createSQL := fmt.Sprintf("CREATE DATABASE \"%s\"", dbName) if err := db.Exec(createSQL).Error; err != nil { log.Fatal("Create database failed:", err) } sqlDB, _ := db.DB() sqlDB.Close() log.Printf("Database %q created successfully", dbName) } // 自动迁移表结构 //func autoMigrate(db *gorm.DB) { // err := db.AutoMigrate( // &TrainingRecord{}, // &HeartRate{}, // ) // if err != nil { // log.Fatal("Auto migrate failed:", err) // } // // // 添加索引(生产环境建议使用迁移工具) // db.Exec("CREATE INDEX IF NOT EXISTS idx_heart_rates_train_time ON heart_rates (train_id, time)") // log.Println("Database schema initialized successfully") //} // ...(保持之前的import和结构体定义不变) // // func initDB() { // dsn := "host=localhost user=postgres password=root dbname=training port=5432 sslmode=disable" // var err error // db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ // CreateBatchSize: 1000, // }) // if err != nil { // log.Fatal("Failed to connect to database:", err) // } // // // 自动迁移(生产环境建议使用迁移工具) // db.AutoMigrate(&TrainingRecord{}, &HeartRate{}) // } func setupRouter() *gin.Engine { r := gin.Default() // 训练记录路由组 trainingGroup := r.Group("/api/trainings") { trainingGroup.POST("create", createTraining) trainingGroup.POST("/:id/heart-rates", uploadHeartRates) //trainingGroup.GET("/:id/report", getTrainingReport) //trainingGroup.GET("/:id/time-series", getTimeSeriesData) } return r } // 更新处理函数 func createTraining(c *gin.Context) { var req CreateTrainingRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } record := TrainingRecord{ TID: req.TID, StartTimestamp: req.StartTime, EndTimestamp: req.EndTime, Name: req.Name, MaxHeartRate: req.MaxHeartRate, Duration: req.Duration, PeopleNum: req.PeopleNum, Evaluation: req.Evaluation, BeltAddresses: req.BeltAddrs, } if err := db.Create(&record).Error; err != nil { if isDuplicateKeyError(err) { c.JSON(http.StatusConflict, gin.H{"error": "训练记录已存在"}) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusCreated, record) } // 判断是否为唯一键冲突错误 func isDuplicateKeyError(err error) bool { if err == nil { return false } var pqErr *pq.Error if errors.As(err, &pqErr) { return pqErr.Code == "23505" // 唯一键冲突错误码 } return false } // 更新上传心率处理 func uploadHeartRates(c *gin.Context) { var req HeartRateUploadRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // 时区处理(接受时区参数,默认为UTC) loc := time.UTC if tz := c.Query("tz"); tz != "" { if l, err := time.LoadLocation(tz); err == nil { loc = l } } heartRates := make([]HeartRate, 0, len(req.Data)) for _, d := range req.Data { t, err := parseTimestamp(d.Timestamp) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ "error": fmt.Sprintf("invalid timestamp: %v", d.Timestamp), }) return } // 转换为指定时区 t = t.In(loc) heartRates = append(heartRates, HeartRate{ TrainID: req.TrainID, BeltAddr: d.BeltAddr, Time: t, Value: d.Value, LastValue: d.LastValue, }) } // 批量插入优化 err := db.Clauses( clause.OnConflict{ Columns: []clause.Column{{Name: "train_id"}, {Name: "time"}}, DoNothing: true, }, clause.Returning{}, ).CreateInBatches(heartRates, 1000).Error if err != nil { log.Printf("Batch insert error: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "数据存储失败"}) return } c.JSON(http.StatusCreated, gin.H{ "message": fmt.Sprintf("成功存储 %d 条心率数据", len(heartRates)), "timezone": loc.String(), }) } // 获取训练报告 func getTrainingReport(c *gin.Context) { trainID := parseUint(c.Param("id")) threshold := c.DefaultQuery("threshold", "120") // 默认危险阈值120 var report TrainingReport report.BeltStats = make(map[string]Stats) // 获取基础统计信息 baseQuery := db.Model(&HeartRate{}).Where("train_id = ?", trainID) // 整体平均和最大心率 baseQuery.Select("AVG(value) as avg, MAX(value) as max"). Row().Scan(&report.AvgHeartRate, &report.MaxHeartRate) // 危险时长计算(假设5秒一个数据点) var dangerCount int64 baseQuery.Where("value >= ?", threshold).Count(&dangerCount) report.DangerSeconds = int(dangerCount) * 5 // 各腰带统计 rows, err := db.Model(&HeartRate{}). Select("belt_addr, AVG(value) as avg, MAX(value) as max, COUNT(*) filter (where value >= ?) as danger", threshold). Where("train_id = ?", trainID). Group("belt_addr"). Rows() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } defer rows.Close() for rows.Next() { var addr string var stat Stats rows.Scan(&addr, &stat.Avg, &stat.Max, &stat.DangerCount) report.BeltStats[addr] = stat } report.TrainID = trainID c.JSON(http.StatusOK, report) } // 获取时间序列数据 func getTimeSeriesData(c *gin.Context) { trainID := parseUint(c.Param("id")) interval := c.DefaultQuery("interval", "1m") // 默认1分钟间隔 var points []TimePoint err := db.Raw(` SELECT date_trunc(?, time) as time, AVG(value) as avg_value, COUNT(*) FILTER (WHERE value >= 120) as danger_count FROM heart_rates WHERE train_id = ? GROUP BY 1 ORDER BY 1 `, interval, trainID).Scan(&points).Error if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, points) } // 辅助函数:字符串转uint func parseUint(s string) uint { var n uint fmt.Sscanf(s, "%d", &n) return n } // 补全main函数 func main() { initDB() r := setupRouter() log.Fatal(r.Run(":8081")) }