517 lines
14 KiB
Go
517 lines
14 KiB
Go
// main.go
|
||
package main
|
||
|
||
import (
	"errors"
	"fmt"
	_ "fmt"
	"log"
	"net/http"
	"net/url"
	"regexp"
	"sort"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/lib/pq"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
|
||
|
||
// CreateTrainingRequest is the JSON body accepted by createTraining.
// StartTime and EndTime are Unix timestamps in milliseconds; the min= rule
// rejects values before 2021-01-01 00:00:00 UTC (1609459200000 ms).
type CreateTrainingRequest struct {
	TID          uint     `json:"tid" binding:"required"`
	StartTime    int64    `json:"startTime" binding:"required,min=1609459200000"` // Unix ms, on/after 2021-01-01 UTC
	EndTime      int64    `json:"endTime" binding:"required,min=1609459200000"`
	Name         string   `json:"name" binding:"required"`
	MaxHeartRate int      `json:"maxHeartRate" binding:"required,min=30,max=250"`
	Duration     int      `json:"duration" binding:"required"` // unit: seconds
	PeopleNum    int      `json:"peopleNum" binding:"required"`
	Evaluation   string   `json:"evaluation" binding:"required"`
	BeltAddrs    []string `json:"beltAddrs" binding:"required"` // copied into TrainingRecord.BeltAddresses
}
|
||
|
||
// HeartRateUploadRequest is the JSON body accepted by uploadHeartRates: a
// batch of heart-rate samples for a single training. The "dive" rule makes the
// binding validate every element of Data against HeartRateData's tags.
type HeartRateUploadRequest struct {
	TrainID uint            `json:"trainId" binding:"required"`
	Data    []HeartRateData `json:"data" binding:"required,dive"`
}
|
||
|
||
// HeartRateData is a single sample reported by one belt.
type HeartRateData struct {
	BeltAddr  string `json:"beltAddr" binding:"required"`
	Timestamp int64  `json:"timestamp" binding:"required,min=1609459200000"` // Unix ms, on/after 2021-01-01 UTC
	Value     int    `json:"value" binding:"required,min=30,max=250"`
	LastValue int    `json:"lastValue" binding:"required,min=30,max=250"` // presumably the previous reading — TODO confirm
}
|
||
|
||
// 数据库模型
|
||
type TrainingRecord struct {
|
||
TID uint `gorm:"primaryKey;column:tid" json:"tid"`
|
||
StartTime time.Time `gorm:"not null" json:"-"`
|
||
EndTime time.Time `gorm:"not null" json:"-"`
|
||
Name string `gorm:"type:varchar(255);default:'训练'" json:"name"`
|
||
MaxHeartRate int `gorm:"not null" json:"maxHeartRate"`
|
||
Duration int `gorm:"not null" json:"duration"`
|
||
PeopleNum int `gorm:"not null" json:"peopleNum"`
|
||
Evaluation string `gorm:"type:varchar(255);default:'适中'" json:"evaluation"`
|
||
BeltAddresses []string `gorm:"type:text[]" json:"beltAddrs"`
|
||
|
||
// 添加毫秒时间戳字段(仅用于JSON序列化)
|
||
StartTimestamp int64 `gorm:"-" json:"startTime"`
|
||
EndTimestamp int64 `gorm:"-" json:"endTime"`
|
||
}
|
||
|
||
// HeartRate is one stored heart-rate sample (table heart_rates). The composite
// primary key is (time, train_id). value and last_value carry database CHECK
// constraints limiting them to 30..250.
//
// NOTE(review): belt_addr is not part of the primary key. Two belts reporting
// at the same instant within one training therefore share a key, and the
// ON CONFLICT (train_id, time) DO NOTHING insert in uploadHeartRates silently
// keeps only one of them. Consider adding belt_addr to both the key and the
// conflict target (requires a schema migration).
type HeartRate struct {
	Time      time.Time `gorm:"primaryKey;type:timestamptz" json:"time"`
	TrainID   uint      `gorm:"primaryKey;index" json:"train_id"`
	BeltAddr  string    `gorm:"type:varchar(255);not null;index" json:"belt_addr"`
	Value     int       `gorm:"check:value BETWEEN 30 AND 250" json:"value"`
	LastValue int       `gorm:"check:last_value BETWEEN 30 AND 250" json:"last_value"`
}
|
||
|
||
// TrainingReport is the JSON response of getTrainingReport: aggregate
// heart-rate statistics for one training.
type TrainingReport struct {
	TrainID       uint             `json:"train_id"`
	AvgHeartRate  float64          `json:"avg_heart_rate"`
	MaxHeartRate  int              `json:"max_heart_rate"`
	DangerSeconds int              `json:"danger_seconds"`          // samples at/above threshold × 5 (assumed seconds per sample)
	BeltStats     map[string]Stats `json:"belt_stats"`              // keyed by belt address
	TimeSeries    []TimePoint      `json:"time_series,omitempty"` // not filled by getTrainingReport; omitted when empty
}
|
||
|
||
// Stats holds the aggregates of a single belt inside a TrainingReport.
type Stats struct {
	Avg         float64 `json:"avg"`          // mean heart rate
	Max         int     `json:"max"`          // highest heart rate
	DangerCount int     `json:"danger_count"` // number of samples at/above the danger threshold
}
|
||
|
||
// TimePoint is one bucket of the time-series response (see getTimeSeriesData).
type TimePoint struct {
	Time        time.Time `json:"time"`         // bucket start (date_trunc of the sample time)
	AvgValue    float64   `json:"avg_value"`    // mean heart rate within the bucket
	DangerCount int       `json:"danger_count"` // samples in the bucket at/above the danger threshold
}
|
||
|
||
// 实现自定义序列化逻辑
|
||
func (t *TrainingRecord) AfterFind(tx *gorm.DB) (err error) {
|
||
t.StartTimestamp = t.StartTime.UnixNano() / int64(time.Millisecond)
|
||
t.EndTimestamp = t.EndTime.UnixNano() / int64(time.Millisecond)
|
||
return
|
||
}
|
||
|
||
// BeforeCreate is a GORM hook that runs before a TrainingRecord is inserted.
// It derives the stored StartTime/EndTime (in UTC) from the millisecond
// StartTimestamp/EndTimestamp supplied through the API.
//
// A zero timestamp means "not supplied". In that case the corresponding
// time.Time is left as-is instead of being overwritten with 1970-01-01, so
// callers that set StartTime/EndTime directly keep their values.
func (t *TrainingRecord) BeforeCreate(tx *gorm.DB) (err error) {
	if t.StartTimestamp != 0 {
		// time.UnixMilli avoids the int64 overflow of ms*int64(time.Millisecond).
		t.StartTime = time.UnixMilli(t.StartTimestamp).UTC()
	}
	if t.EndTimestamp != 0 {
		t.EndTime = time.UnixMilli(t.EndTimestamp).UTC()
	}
	return nil
}
|
||
|
||
// 更新处理函数
|
||
//func createTraining(c *gin.Context) {
|
||
// var req CreateTrainingRequest
|
||
// if err := c.ShouldBindJSON(&req); err != nil {
|
||
// c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||
// return
|
||
// }
|
||
//
|
||
// record := TrainingRecord{
|
||
// TID: req.TID,
|
||
// StartTimestamp: req.StartTime,
|
||
// EndTimestamp: req.EndTime,
|
||
// Name: req.Name,
|
||
// MaxHeartRate: req.MaxHeartRate,
|
||
// Duration: req.Duration,
|
||
// PeopleNum: req.PeopleNum,
|
||
// Evaluation: req.Evaluation,
|
||
// BeltAddresses: req.BeltAddrs,
|
||
// }
|
||
//
|
||
// if err := db.Create(&record).Error; err != nil {
|
||
// if isDuplicateKeyError(err) {
|
||
// c.JSON(http.StatusConflict, gin.H{"error": "训练记录已存在"})
|
||
// return
|
||
// }
|
||
// c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
// return
|
||
// }
|
||
//
|
||
// c.JSON(http.StatusCreated, record)
|
||
//}
|
||
|
||
// 统一时间处理函数
|
||
func parseTimestamp(ms int64) (time.Time, error) {
|
||
if ms < 1609459200000 { // 2021-01-01 00:00:00 UTC
|
||
return time.Time{}, fmt.Errorf("无效的时间戳")
|
||
}
|
||
return time.Unix(0, ms*int64(time.Millisecond)).UTC(), nil
|
||
}
|
||
|
||
// 更新心率数据处理
|
||
func processHeartRateData(req *HeartRateUploadRequest) ([]HeartRate, error) {
|
||
var rates []HeartRate
|
||
|
||
for _, d := range req.Data {
|
||
t, err := parseTimestamp(d.Timestamp)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("无效的时间戳: %d", d.Timestamp)
|
||
}
|
||
|
||
rates = append(rates, HeartRate{
|
||
TrainID: req.TrainID,
|
||
BeltAddr: d.BeltAddr,
|
||
Time: t,
|
||
Value: d.Value,
|
||
LastValue: d.LastValue,
|
||
})
|
||
}
|
||
|
||
// 按时间排序
|
||
sort.Slice(rates, func(i, j int) bool {
|
||
return rates[i].Time.Before(rates[j].Time)
|
||
})
|
||
|
||
return rates, nil
|
||
}
|
||
|
||
// autoMigrate creates or updates the training_records and heart_rates tables,
// attaches PostgreSQL table comments, and ensures the composite index
// idx_heart_rates_main on heart_rates(train_id, belt_addr, time DESC).
//
// PostgreSQL has no inline table COMMENT clause. A MySQL-style
// gorm:table_options " comment '...'" suffix makes CREATE TABLE fail, so the
// comments are applied afterwards with COMMENT ON TABLE.
//
// A migration failure is fatal, because the handlers cannot work without the
// tables. Comment and index failures are only logged.
func autoMigrate(db *gorm.DB) {
	if err := db.AutoMigrate(&TrainingRecord{}, &HeartRate{}); err != nil {
		log.Fatal("Auto migrate failed:", err)
	}

	stmts := []string{
		`COMMENT ON TABLE training_records IS '训练记录表'`,
		`COMMENT ON TABLE heart_rates IS '心率数据表'`,
		`CREATE INDEX IF NOT EXISTS idx_heart_rates_main
			ON heart_rates (train_id, belt_addr, time DESC)`,
	}
	for _, stmt := range stmts {
		if err := db.Exec(stmt).Error; err != nil {
			log.Printf("Schema statement failed (%s): %v", stmt, err)
		}
	}
}
|
||
|
||
// db is the process-wide GORM handle used by the HTTP handlers.
// NOTE(review): initDB must assign this package-level variable (not a
// shadowing local declared with :=) before the router starts serving.
var db *gorm.DB

const (
	defaultDB = "postgres" // maintenance database to connect to when creating a new database
)
|
||
|
||
func initDB() {
|
||
// 解析原始DSN
|
||
dsn := "host=localhost user=postgres password=root dbname=training port=5432 sslmode=disable"
|
||
parsedDSN, err := url.Parse(dsn)
|
||
if err != nil {
|
||
log.Fatal("Invalid DSN:", err)
|
||
}
|
||
query := parsedDSN.Query()
|
||
query.Set("dbname", defaultDB)
|
||
query.Set("sslmode", "disable")
|
||
query.Set("port", "5432")
|
||
query.Set("host", "localhost")
|
||
query.Set("user", "postgres")
|
||
query.Set("password", "root")
|
||
defaultDSN := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s",
|
||
query.Get("host"),
|
||
query.Get("user"),
|
||
query.Get("password"),
|
||
query.Get("dbname"),
|
||
query.Get("port"),
|
||
query.Get("sslmode"),
|
||
)
|
||
//defaultDSNUrl, _ := url.Parse(defaultDSN)
|
||
|
||
// 提取数据库名称
|
||
dbName := "training"
|
||
|
||
// 第一步:尝试连接目标数据库
|
||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||
if err != nil {
|
||
// 检查是否是数据库不存在的错误(PostgreSQL错误码3D000)
|
||
if isDatabaseNotExistError(err) {
|
||
log.Printf("Database %q does not exist, attempting to create...", dbName)
|
||
createDatabase(defaultDSN, dbName)
|
||
} else {
|
||
log.Fatal("Database connection failed:", err)
|
||
}
|
||
} else {
|
||
log.Printf("Database %q already exists", dbName)
|
||
sqlDB, _ := db.DB()
|
||
sqlDB.Close()
|
||
}
|
||
|
||
// 再次连接目标数据库
|
||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||
CreateBatchSize: 1000,
|
||
})
|
||
if err != nil {
|
||
log.Fatal("Final database connection failed:", err)
|
||
}
|
||
|
||
// 自动迁移表结构
|
||
autoMigrate(db)
|
||
}
|
||
|
||
// 检查是否为数据库不存在错误
|
||
func isDatabaseNotExistError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
errStr := err.Error()
|
||
return containsErrorCode(errStr, "3D000")
|
||
}
|
||
|
||
// containsErrorCode reports whether errStr contains code as a standalone word
// (bounded by \b on both sides). An example is the SQLSTATE "3D000" inside
// "... (SQLSTATE 3D000)".
//
// code is matched literally. Regex metacharacters are escaped, so an input
// such as "(" cannot make the pattern invalid (which would panic in
// MustCompile) or match unintended text. An empty code never matches.
func containsErrorCode(errStr, code string) bool {
	if code == "" {
		return false
	}
	re := regexp.MustCompile(`\b` + regexp.QuoteMeta(code) + `\b`)
	return re.MatchString(errStr)
}
|
||
|
||
// createDatabase connects using parsedDSN, a key=value DSN for an existing
// maintenance database, and runs CREATE DATABASE dbName. The connecting role
// needs superuser (or CREATEDB) privileges.
//
// If another process created the database concurrently (SQLSTATE 42P04,
// duplicate_database), this is treated as success. Any other failure is fatal.
func createDatabase(parsedDSN string, dbName string) {
	admin, err := gorm.Open(postgres.Open(parsedDSN), &gorm.Config{})
	if err != nil {
		log.Fatal("Connect to default database failed:", err)
	}
	sqlDB, err := admin.DB()
	if err != nil {
		log.Fatal("Get underlying connection failed:", err)
	}
	defer sqlDB.Close()

	// CREATE DATABASE cannot use bind parameters, so the identifier is quoted
	// (embedded double quotes escaped) instead of interpolated raw.
	createSQL := "CREATE DATABASE " + pq.QuoteIdentifier(dbName)
	if err := admin.Exec(createSQL).Error; err != nil {
		if containsErrorCode(err.Error(), "42P04") {
			log.Printf("Database %q was created concurrently", dbName)
			return
		}
		log.Fatal("Create database failed:", err)
	}

	log.Printf("Database %q created successfully", dbName)
}
|
||
|
||
// 自动迁移表结构
|
||
//func autoMigrate(db *gorm.DB) {
|
||
// err := db.AutoMigrate(
|
||
// &TrainingRecord{},
|
||
// &HeartRate{},
|
||
// )
|
||
// if err != nil {
|
||
// log.Fatal("Auto migrate failed:", err)
|
||
// }
|
||
//
|
||
// // 添加索引(生产环境建议使用迁移工具)
|
||
// db.Exec("CREATE INDEX IF NOT EXISTS idx_heart_rates_train_time ON heart_rates (train_id, time)")
|
||
// log.Println("Database schema initialized successfully")
|
||
//}
|
||
|
||
// ...(保持之前的import和结构体定义不变)
|
||
//
|
||
// func initDB() {
|
||
// dsn := "host=localhost user=postgres password=root dbname=training port=5432 sslmode=disable"
|
||
// var err error
|
||
// db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||
// CreateBatchSize: 1000,
|
||
// })
|
||
// if err != nil {
|
||
// log.Fatal("Failed to connect to database:", err)
|
||
// }
|
||
//
|
||
// // 自动迁移(生产环境建议使用迁移工具)
|
||
// db.AutoMigrate(&TrainingRecord{}, &HeartRate{})
|
||
// }
|
||
func setupRouter() *gin.Engine {
|
||
r := gin.Default()
|
||
|
||
// 训练记录路由组
|
||
trainingGroup := r.Group("/api/trainings")
|
||
{
|
||
trainingGroup.POST("create", createTraining)
|
||
trainingGroup.POST("/:id/heart-rates", uploadHeartRates)
|
||
//trainingGroup.GET("/:id/report", getTrainingReport)
|
||
//trainingGroup.GET("/:id/time-series", getTimeSeriesData)
|
||
}
|
||
|
||
return r
|
||
}
|
||
|
||
// createTraining handles POST /api/trainings/create.
//
// It binds and validates a CreateTrainingRequest and then inserts a
// TrainingRecord; BeforeCreate converts the millisecond timestamps. Responses:
//   - 400 for an invalid body or when endTime precedes startTime.
//   - 409 when a record with the same tid already exists.
//   - 500 on other storage errors. Details are logged, not returned to the
//     client, matching uploadHeartRates.
//   - 201 with the created record on success.
func createTraining(c *gin.Context) {
	var req CreateTrainingRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	if req.EndTime < req.StartTime {
		c.JSON(http.StatusBadRequest, gin.H{"error": "endTime must not be earlier than startTime"})
		return
	}

	record := TrainingRecord{
		TID:            req.TID,
		StartTimestamp: req.StartTime,
		EndTimestamp:   req.EndTime,
		Name:           req.Name,
		MaxHeartRate:   req.MaxHeartRate,
		Duration:       req.Duration,
		PeopleNum:      req.PeopleNum,
		Evaluation:     req.Evaluation,
		BeltAddresses:  req.BeltAddrs,
	}

	if err := db.Create(&record).Error; err != nil {
		if isDuplicateKeyError(err) {
			c.JSON(http.StatusConflict, gin.H{"error": "训练记录已存在"})
			return
		}
		log.Printf("Create training error: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "数据存储失败"})
		return
	}

	c.JSON(http.StatusCreated, record)
}
|
||
|
||
// isDuplicateKeyError reports whether err is a PostgreSQL unique-violation
// (SQLSTATE 23505).
//
// The database is opened with gorm.io/driver/postgres, which is backed by pgx,
// not lib/pq. Its errors are therefore never *pq.Error. The pq check is kept
// for the case where lib/pq is the active driver. Otherwise it falls back to
// the SQLSTATE in the error text; pgx formats errors as "... (SQLSTATE 23505)".
func isDuplicateKeyError(err error) bool {
	if err == nil {
		return false
	}

	var pqErr *pq.Error
	if errors.As(err, &pqErr) {
		return pqErr.Code == "23505" // unique_violation
	}

	return containsErrorCode(err.Error(), "23505")
}
|
||
|
||
// uploadHeartRates handles POST /api/trainings/:id/heart-rates.
//
// The body is a HeartRateUploadRequest validated by its binding tags. When :id
// is a positive integer it must equal the body's trainId; otherwise it is
// ignored. Samples are converted with processHeartRateData and inserted in
// batches of 1000. Rows whose (train_id, time) already exist are skipped
// (ON CONFLICT DO NOTHING).
//
// The optional query parameter tz names a time zone loadable by
// time.LoadLocation. Sample times are converted into it and the zone name is
// echoed back. An unknown zone silently falls back to UTC. The conversion only
// changes the time's location, not the instant being stored.
//
// On success it responds 201 with the number of rows actually stored, plus the
// received and skipped counts.
func uploadHeartRates(c *gin.Context) {
	var req HeartRateUploadRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// The URL also names a training; refuse to file samples under a different one.
	if pathID := parseUint(c.Param("id")); pathID != 0 && pathID != req.TrainID {
		c.JSON(http.StatusBadRequest, gin.H{"error": "trainId in body does not match :id in URL"})
		return
	}

	// Time zone handling: accept a tz parameter, defaulting to UTC.
	loc := time.UTC
	if tz := c.Query("tz"); tz != "" {
		if l, err := time.LoadLocation(tz); err == nil {
			loc = l
		}
	}

	heartRates, err := processHeartRateData(&req)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	for i := range heartRates {
		heartRates[i].Time = heartRates[i].Time.In(loc)
	}

	// Batched insert. RowsAffected is summed across batches and excludes rows
	// skipped by DO NOTHING.
	result := db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "train_id"}, {Name: "time"}},
		DoNothing: true,
	}).CreateInBatches(heartRates, 1000)
	if result.Error != nil {
		log.Printf("Batch insert error: %v", result.Error)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "数据存储失败"})
		return
	}

	stored := result.RowsAffected
	c.JSON(http.StatusCreated, gin.H{
		"message":  fmt.Sprintf("成功存储 %d 条心率数据", stored),
		"timezone": loc.String(),
		"received": len(heartRates),
		"skipped":  int64(len(heartRates)) - stored,
	})
}
|
||
|
||
// getTrainingReport handles GET /api/trainings/:id/report (the route is
// currently not registered in setupRouter). It returns the overall average and
// maximum heart rate, the danger duration, and per-belt statistics for one
// training.
//
// The optional query parameter threshold is an integer heart rate at or above
// which a sample counts as dangerous (default 120).
//
// Responses: 400 for an invalid :id or threshold; 500 when a query or scan
// fails; 200 with a TrainingReport otherwise.
func getTrainingReport(c *gin.Context) {
	trainID := parseUint(c.Param("id"))
	if trainID == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid training id"})
		return
	}
	threshold, err := strconv.Atoi(c.DefaultQuery("threshold", "120"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid threshold"})
		return
	}

	report := TrainingReport{
		TrainID:   trainID,
		BeltStats: make(map[string]Stats),
	}

	// newQuery returns a fresh statement for every query. Reusing a single GORM
	// chain across Select/Count calls accumulates clauses between executions.
	newQuery := func() *gorm.DB {
		return db.Model(&HeartRate{}).Where("train_id = ?", trainID)
	}

	// Overall average and maximum. COALESCE turns the NULLs produced by a
	// training with no samples into 0 instead of failing the scan.
	if err := newQuery().
		Select("COALESCE(AVG(value), 0) AS avg, COALESCE(MAX(value), 0) AS max").
		Row().
		Scan(&report.AvgHeartRate, &report.MaxHeartRate); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Danger duration. NOTE(review): assumes one sample every 5 seconds —
	// confirm the device sampling interval.
	var dangerCount int64
	if err := newQuery().Where("value >= ?", threshold).Count(&dangerCount).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	report.DangerSeconds = int(dangerCount) * 5

	// Per-belt statistics.
	rows, err := newQuery().
		Select("belt_addr, AVG(value) AS avg, MAX(value) AS max, COUNT(*) FILTER (WHERE value >= ?) AS danger", threshold).
		Group("belt_addr").
		Rows()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer rows.Close()

	for rows.Next() {
		var addr string
		var stat Stats
		if err := rows.Scan(&addr, &stat.Avg, &stat.Max, &stat.DangerCount); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		report.BeltStats[addr] = stat
	}
	if err := rows.Err(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, report)
}
|
||
|
||
// getTimeSeriesData handles GET /api/trainings/:id/time-series (the route is
// currently not registered in setupRouter). It returns per-bucket average heart
// rate and danger counts, ordered by bucket start.
//
// Query parameters:
//   - interval: bucket size. Accepts 1s/1m/1h/1d or the PostgreSQL date_trunc
//     names second/minute/hour/day (default "1m").
//   - threshold: integer heart rate at or above which a sample counts as
//     dangerous (default 120).
//
// Responses: 400 for an invalid :id, interval or threshold; 500 on query
// failure; 200 with []TimePoint otherwise.
func getTimeSeriesData(c *gin.Context) {
	trainID := parseUint(c.Param("id"))
	if trainID == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid training id"})
		return
	}

	// date_trunc only understands field names such as 'minute'; passing a raw
	// "1m" makes PostgreSQL reject the query. Map the supported shorthands.
	unit, ok := timeSeriesTruncUnit(c.DefaultQuery("interval", "1m"))
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported interval; use 1s, 1m, 1h or 1d"})
		return
	}

	threshold, err := strconv.Atoi(c.DefaultQuery("threshold", "120"))
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid threshold"})
		return
	}

	var points []TimePoint
	err = db.Raw(`
		SELECT
			date_trunc(?, time) AS time,
			AVG(value) AS avg_value,
			COUNT(*) FILTER (WHERE value >= ?) AS danger_count
		FROM heart_rates
		WHERE train_id = ?
		GROUP BY 1
		ORDER BY 1
	`, unit, threshold, trainID).Scan(&points).Error
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, points)
}

// timeSeriesTruncUnit maps an interval query value to a PostgreSQL date_trunc
// field name, reporting false for unsupported values.
func timeSeriesTruncUnit(interval string) (string, bool) {
	switch interval {
	case "1s", "second":
		return "second", true
	case "1m", "minute":
		return "minute", true
	case "1h", "hour":
		return "hour", true
	case "1d", "day":
		return "day", true
	}
	return "", false
}
|
||
|
||
// 辅助函数:字符串转uint
|
||
func parseUint(s string) uint {
|
||
var n uint
|
||
fmt.Sscanf(s, "%d", &n)
|
||
return n
|
||
}
|
||
|
||
// main initializes the database (exiting on failure) and then serves the
// HTTP API on port 8081 until the server stops with an error.
func main() {
	initDB()

	router := setupRouter()
	log.Fatal(router.Run(":8081"))
}
|