Files
hr_data_analyzer/main.go.bak
2025-03-19 14:55:12 +08:00

517 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// main.go
package main
import (
"errors"
"fmt"
_ "fmt"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
"gorm.io/driver/postgres"
"gorm.io/gorm/clause"
"log"
"net/http"
"net/url"
"regexp"
"sort"
"time"
"gorm.io/gorm"
)
// CreateTrainingRequest is the JSON body accepted by createTraining.
// StartTime and EndTime are Unix timestamps in milliseconds; the binding
// tags require both to be at least 1609459200000 (2021-01-01T00:00:00Z).
type CreateTrainingRequest struct {
	TID          uint     `json:"tid" binding:"required"`
	StartTime    int64    `json:"startTime" binding:"required,min=1609459200000"` // 2021-01-01 or later
	EndTime      int64    `json:"endTime" binding:"required,min=1609459200000"`
	Name         string   `json:"name" binding:"required"`
	MaxHeartRate int      `json:"maxHeartRate" binding:"required,min=30,max=250"`
	Duration     int      `json:"duration" binding:"required"` // unit: seconds
	PeopleNum    int      `json:"peopleNum" binding:"required"`
	Evaluation   string   `json:"evaluation" binding:"required"`
	BeltAddrs    []string `json:"beltAddrs" binding:"required"`
}
// HeartRateUploadRequest is the JSON body accepted by uploadHeartRates: the
// training the samples belong to and the samples themselves. The "dive"
// binding option makes the validator check every HeartRateData element
// against that type's own binding tags.
type HeartRateUploadRequest struct {
	TrainID uint            `json:"trainId" binding:"required"`
	Data    []HeartRateData `json:"data" binding:"required,dive"`
}
// HeartRateData is a single sample inside a HeartRateUploadRequest.
type HeartRateData struct {
	BeltAddr  string `json:"beltAddr" binding:"required"`                    // belt address, stored as HeartRate.BeltAddr
	Timestamp int64  `json:"timestamp" binding:"required,min=1609459200000"` // Unix milliseconds, 2021-01-01 or later
	Value     int    `json:"value" binding:"required,min=30,max=250"`        // heart-rate reading, 30-250
	LastValue int    `json:"lastValue" binding:"required,min=30,max=250"`    // presumably the previous reading; 30-250
}
// 数据库模型
type TrainingRecord struct {
TID uint `gorm:"primaryKey;column:tid" json:"tid"`
StartTime time.Time `gorm:"not null" json:"-"`
EndTime time.Time `gorm:"not null" json:"-"`
Name string `gorm:"type:varchar(255);default:'训练'" json:"name"`
MaxHeartRate int `gorm:"not null" json:"maxHeartRate"`
Duration int `gorm:"not null" json:"duration"`
PeopleNum int `gorm:"not null" json:"peopleNum"`
Evaluation string `gorm:"type:varchar(255);default:'适中'" json:"evaluation"`
BeltAddresses []string `gorm:"type:text[]" json:"beltAddrs"`
// 添加毫秒时间戳字段仅用于JSON序列化
StartTimestamp int64 `gorm:"-" json:"startTime"`
EndTimestamp int64 `gorm:"-" json:"endTime"`
}
// HeartRate is the GORM model for a single heart-rate sample.
//
// The primary key is the pair (Time, TrainID). TrainID and BeltAddr also get
// single-column indexes. Value and LastValue carry CHECK constraints limiting
// them to 30-250.
//
// NOTE(review): BeltAddr is not part of the primary key, so two belts that
// report the same timestamp for the same training share a key. With the
// ON CONFLICT (train_id, time) DO NOTHING insert in uploadHeartRates, only
// one of those samples is stored. Adding belt_addr to the key would also
// require changing that conflict target.
type HeartRate struct {
	Time      time.Time `gorm:"primaryKey;type:timestamptz" json:"time"`
	TrainID   uint      `gorm:"primaryKey;index" json:"train_id"`
	BeltAddr  string    `gorm:"type:varchar(255);not null;index" json:"belt_addr"`
	Value     int       `gorm:"check:value BETWEEN 30 AND 250" json:"value"`
	LastValue int       `gorm:"check:last_value BETWEEN 30 AND 250" json:"last_value"`
}
// TrainingReport is the JSON response built by getTrainingReport.
type TrainingReport struct {
	TrainID       uint             `json:"train_id"`
	AvgHeartRate  float64          `json:"avg_heart_rate"`
	MaxHeartRate  int              `json:"max_heart_rate"`
	DangerSeconds int              `json:"danger_seconds"`        // estimated seconds at or above the danger threshold
	BeltStats     map[string]Stats `json:"belt_stats"`            // per-belt figures keyed by belt address
	TimeSeries    []TimePoint      `json:"time_series,omitempty"` // not filled by getTrainingReport; omitted when empty
}
// Stats summarizes the samples of one belt inside a TrainingReport.
type Stats struct {
	Avg         float64 `json:"avg"`          // average heart rate
	Max         int     `json:"max"`          // maximum heart rate
	DangerCount int     `json:"danger_count"` // number of samples at or above the danger threshold
}
// TimePoint is one bucket of the series returned by getTimeSeriesData.
type TimePoint struct {
	Time        time.Time `json:"time"`         // bucket start (date_trunc of the sample time)
	AvgValue    float64   `json:"avg_value"`    // average heart rate in the bucket
	DangerCount int       `json:"danger_count"` // samples in the bucket at or above the danger threshold
}
// AfterFind is a GORM hook that runs after a TrainingRecord is loaded. It
// fills the JSON-only StartTimestamp/EndTimestamp fields with StartTime and
// EndTime expressed as Unix milliseconds.
func (t *TrainingRecord) AfterFind(tx *gorm.DB) (err error) {
	// Unix()*1000 plus the sub-second milliseconds stays in range for
	// practical dates. UnixNano() is undefined before 1678 or after 2262,
	// which includes the zero time.Time.
	t.StartTimestamp = t.StartTime.Unix()*1000 + int64(t.StartTime.Nanosecond())/int64(time.Millisecond)
	t.EndTimestamp = t.EndTime.Unix()*1000 + int64(t.EndTime.Nanosecond())/int64(time.Millisecond)
	return nil
}
// BeforeCreate is a GORM hook that runs before a TrainingRecord is inserted.
// It derives the stored StartTime/EndTime columns (in UTC) from the
// millisecond StartTimestamp/EndTimestamp fields, overwriting any value the
// time fields already held.
func (t *TrainingRecord) BeforeCreate(tx *gorm.DB) (err error) {
	// Split into whole seconds and the remaining milliseconds. Multiplying
	// the millisecond value by 1e6 would overflow int64 past the year 2262.
	t.StartTime = time.Unix(t.StartTimestamp/1000, (t.StartTimestamp%1000)*int64(time.Millisecond)).UTC()
	t.EndTime = time.Unix(t.EndTimestamp/1000, (t.EndTimestamp%1000)*int64(time.Millisecond)).UTC()
	return nil
}
// 更新处理函数
//func createTraining(c *gin.Context) {
// var req CreateTrainingRequest
// if err := c.ShouldBindJSON(&req); err != nil {
// c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
// return
// }
//
// record := TrainingRecord{
// TID: req.TID,
// StartTimestamp: req.StartTime,
// EndTimestamp: req.EndTime,
// Name: req.Name,
// MaxHeartRate: req.MaxHeartRate,
// Duration: req.Duration,
// PeopleNum: req.PeopleNum,
// Evaluation: req.Evaluation,
// BeltAddresses: req.BeltAddrs,
// }
//
// if err := db.Create(&record).Error; err != nil {
// if isDuplicateKeyError(err) {
// c.JSON(http.StatusConflict, gin.H{"error": "训练记录已存在"})
// return
// }
// c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
// return
// }
//
// c.JSON(http.StatusCreated, record)
//}
// parseTimestamp converts a Unix timestamp in milliseconds into a UTC
// time.Time. Timestamps before 2021-01-01T00:00:00Z are rejected with an
// error.
func parseTimestamp(ms int64) (time.Time, error) {
	const minValidMillis = 1609459200000 // 2021-01-01 00:00:00 UTC
	if ms < minValidMillis {
		return time.Time{}, errors.New("无效的时间戳")
	}
	// ms comes from request bodies. Splitting it into seconds and remainder
	// avoids the int64 overflow of ms*1e6 for far-future values.
	sec, rem := ms/1000, ms%1000
	return time.Unix(sec, rem*int64(time.Millisecond)).UTC(), nil
}
// processHeartRateData converts the samples of an upload request into
// HeartRate rows for req.TrainID and returns them ordered by time.
//
// It fails on the first sample whose timestamp parseTimestamp rejects. The
// sort is stable, so samples that share a timestamp (for example several
// belts reporting at the same instant) keep their upload order.
func processHeartRateData(req *HeartRateUploadRequest) ([]HeartRate, error) {
	rates := make([]HeartRate, 0, len(req.Data))
	for _, d := range req.Data {
		t, err := parseTimestamp(d.Timestamp)
		if err != nil {
			return nil, fmt.Errorf("无效的时间戳: %d", d.Timestamp)
		}
		rates = append(rates, HeartRate{
			TrainID:   req.TrainID,
			BeltAddr:  d.BeltAddr,
			Time:      t,
			Value:     d.Value,
			LastValue: d.LastValue,
		})
	}

	sort.SliceStable(rates, func(i, j int) bool {
		return rates[i].Time.Before(rates[j].Time)
	})
	return rates, nil
}
// autoMigrate creates or updates the tables for TrainingRecord and HeartRate.
// It then attaches a comment to each table and ensures the composite index
// idx_heart_rates_main on heart_rates (train_id, belt_addr, time DESC) exists.
//
// A failed AutoMigrate is fatal, because the handlers cannot work without the
// tables. The comments and the index are best-effort and are only logged.
func autoMigrate(db *gorm.DB) {
	if err := db.AutoMigrate(&TrainingRecord{}, &HeartRate{}); err != nil {
		log.Fatal("Auto migrate failed:", err)
	}

	// gorm appends "gorm:table_options" verbatim to CREATE TABLE. An inline
	// COMMENT '...' clause there is MySQL syntax that PostgreSQL rejects, so
	// the comments are attached with COMMENT ON TABLE once the tables exist.
	statements := []string{
		`COMMENT ON TABLE training_records IS '训练记录表'`,
		`COMMENT ON TABLE heart_rates IS '心率数据表'`,
		`CREATE INDEX IF NOT EXISTS idx_heart_rates_main
    ON heart_rates (train_id, belt_addr, time DESC)`,
	}
	for _, stmt := range statements {
		if err := db.Exec(stmt).Error; err != nil {
			log.Printf("Schema setup statement failed (continuing): %v", err)
		}
	}
}
// db is the package-level GORM handle that the HTTP handlers query through.
var db *gorm.DB

// defaultDB is the maintenance database to connect to when the target
// database has to be created first.
const defaultDB = "postgres"
func initDB() {
// 解析原始DSN
dsn := "host=localhost user=postgres password=root dbname=training port=5432 sslmode=disable"
parsedDSN, err := url.Parse(dsn)
if err != nil {
log.Fatal("Invalid DSN:", err)
}
query := parsedDSN.Query()
query.Set("dbname", defaultDB)
query.Set("sslmode", "disable")
query.Set("port", "5432")
query.Set("host", "localhost")
query.Set("user", "postgres")
query.Set("password", "root")
defaultDSN := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s",
query.Get("host"),
query.Get("user"),
query.Get("password"),
query.Get("dbname"),
query.Get("port"),
query.Get("sslmode"),
)
//defaultDSNUrl, _ := url.Parse(defaultDSN)
// 提取数据库名称
dbName := "training"
// 第一步:尝试连接目标数据库
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
// 检查是否是数据库不存在的错误PostgreSQL错误码3D000
if isDatabaseNotExistError(err) {
log.Printf("Database %q does not exist, attempting to create...", dbName)
createDatabase(defaultDSN, dbName)
} else {
log.Fatal("Database connection failed:", err)
}
} else {
log.Printf("Database %q already exists", dbName)
sqlDB, _ := db.DB()
sqlDB.Close()
}
// 再次连接目标数据库
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
CreateBatchSize: 1000,
})
if err != nil {
log.Fatal("Final database connection failed:", err)
}
// 自动迁移表结构
autoMigrate(db)
}
// isDatabaseNotExistError reports whether err carries PostgreSQL error code
// 3D000, which the server returns when the requested database does not
// exist. The code is looked up in the error text, so this works for any
// driver that includes the SQLSTATE in its message. A nil error never
// matches.
func isDatabaseNotExistError(err error) bool {
	return err != nil && containsErrorCode(err.Error(), "3D000")
}
// containsErrorCode reports whether errStr contains code as a standalone word,
// delimited by \b word boundaries. An example is the "3D000" in
// "... (SQLSTATE 3D000)". code is matched literally: regexp metacharacters
// in it are escaped rather than interpreted, so MustCompile cannot panic on
// an unusual code.
func containsErrorCode(errStr, code string) bool {
	re := regexp.MustCompile(`\b` + regexp.QuoteMeta(code) + `\b`)
	return re.MatchString(errStr)
}
// createDatabase connects with parsedDSN and runs CREATE DATABASE for dbName.
// parsedDSN must point at an existing maintenance database; initDB passes
// one for defaultDB.
//
// If the database already exists (SQLSTATE 42P04, for example because it was
// created concurrently), createDatabase logs this and returns. Any other
// failure terminates the process via log.Fatal. The connecting role needs the
// CREATEDB privilege or superuser rights.
func createDatabase(parsedDSN string, dbName string) {
	adminDB, err := gorm.Open(postgres.Open(parsedDSN), &gorm.Config{})
	if err != nil {
		log.Fatal("Connect to default database failed:", err)
	}

	// CREATE DATABASE cannot take bind parameters, so the identifier is
	// quoted and escaped instead of being spliced in raw.
	createSQL := "CREATE DATABASE " + pq.QuoteIdentifier(dbName)
	err = adminDB.Exec(createSQL).Error

	// The admin pool is not needed any more. A failing Close is not
	// actionable here, so its error is intentionally dropped.
	if sqlDB, dbErr := adminDB.DB(); dbErr == nil {
		_ = sqlDB.Close()
	}

	switch {
	case err == nil:
		log.Printf("Database %q created successfully", dbName)
	case containsErrorCode(err.Error(), "42P04"):
		log.Printf("Database %q already exists", dbName)
	default:
		log.Fatal("Create database failed:", err)
	}
}
// 自动迁移表结构
//func autoMigrate(db *gorm.DB) {
// err := db.AutoMigrate(
// &TrainingRecord{},
// &HeartRate{},
// )
// if err != nil {
// log.Fatal("Auto migrate failed:", err)
// }
//
// // 添加索引(生产环境建议使用迁移工具)
// db.Exec("CREATE INDEX IF NOT EXISTS idx_heart_rates_train_time ON heart_rates (train_id, time)")
// log.Println("Database schema initialized successfully")
//}
// ...保持之前的import和结构体定义不变
//
// func initDB() {
// dsn := "host=localhost user=postgres password=root dbname=training port=5432 sslmode=disable"
// var err error
// db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
// CreateBatchSize: 1000,
// })
// if err != nil {
// log.Fatal("Failed to connect to database:", err)
// }
//
// // 自动迁移(生产环境建议使用迁移工具)
// db.AutoMigrate(&TrainingRecord{}, &HeartRate{})
// }
// setupRouter builds a gin engine with gin.Default and registers the training
// API under /api/trainings:
//
//	POST /api/trainings/create          -> createTraining
//	POST /api/trainings/:id/heart-rates -> uploadHeartRates
//
// The report and time-series handlers exist but are not registered yet.
func setupRouter() *gin.Engine {
	engine := gin.Default()

	trainings := engine.Group("/api/trainings")
	trainings.POST("create", createTraining)
	trainings.POST("/:id/heart-rates", uploadHeartRates)
	// trainings.GET("/:id/report", getTrainingReport)
	// trainings.GET("/:id/time-series", getTimeSeriesData)

	return engine
}
// createTraining handles POST /api/trainings/create.
//
// It binds a CreateTrainingRequest, rejects requests whose endTime is earlier
// than startTime, and inserts a TrainingRecord. The record's BeforeCreate hook
// turns the millisecond timestamps into StartTime/EndTime.
//
// Responses: 201 with the stored record; 400 on invalid input; 409 when
// isDuplicateKeyError recognizes a unique-key violation; 500 on other
// database errors.
func createTraining(c *gin.Context) {
	var req CreateTrainingRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// The binding tags validate each timestamp on its own; their relative
	// order has to be checked explicitly.
	if req.EndTime < req.StartTime {
		c.JSON(http.StatusBadRequest, gin.H{"error": "endTime must not be earlier than startTime"})
		return
	}

	record := TrainingRecord{
		TID:            req.TID,
		StartTimestamp: req.StartTime,
		EndTimestamp:   req.EndTime,
		Name:           req.Name,
		MaxHeartRate:   req.MaxHeartRate,
		Duration:       req.Duration,
		PeopleNum:      req.PeopleNum,
		Evaluation:     req.Evaluation,
		BeltAddresses:  req.BeltAddrs,
	}

	if err := db.Create(&record).Error; err != nil {
		if isDuplicateKeyError(err) {
			c.JSON(http.StatusConflict, gin.H{"error": "训练记录已存在"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusCreated, record)
}
// isDuplicateKeyError reports whether err is a PostgreSQL unique_violation
// (SQLSTATE 23505).
//
// A *pq.Error is checked first. Errors coming through gorm.io/driver/postgres
// are produced by pgx, not lib/pq, so that check alone never matches them.
// pgx error messages end with "(SQLSTATE <code>)", so the code is also looked
// up in the error text, the same way isDatabaseNotExistError detects 3D000.
func isDuplicateKeyError(err error) bool {
	if err == nil {
		return false
	}
	var pqErr *pq.Error
	if errors.As(err, &pqErr) {
		return pqErr.Code == "23505" // unique_violation
	}
	return containsErrorCode(err.Error(), "23505")
}
// uploadHeartRates handles POST /api/trainings/:id/heart-rates.
//
// It binds a HeartRateUploadRequest and converts every sample's millisecond
// timestamp with parseTimestamp, returning 400 on the first invalid one. It
// then inserts the rows in batches of 1000. Rows whose (train_id, time) key
// already exists are skipped via ON CONFLICT DO NOTHING. For that reason the
// success message reports the number of rows actually inserted, while
// "received" is the number of samples in the request. An insert failure
// returns 500.
//
// The optional "tz" query parameter names a time.LoadLocation location. It
// falls back to UTC when empty or unknown, and the location used is echoed
// back as "timezone".
// NOTE(review): the time column is timestamptz, so the location only changes
// the Go-side representation, not the stored instant. Confirm tz is still
// needed.
// NOTE(review): the :id path parameter is not read; the training ID comes
// from the body's trainId field.
func uploadHeartRates(c *gin.Context) {
	var req HeartRateUploadRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Best-effort time zone: unknown names silently fall back to UTC.
	loc := time.UTC
	if tz := c.Query("tz"); tz != "" {
		if l, err := time.LoadLocation(tz); err == nil {
			loc = l
		}
	}

	heartRates := make([]HeartRate, 0, len(req.Data))
	for _, d := range req.Data {
		t, err := parseTimestamp(d.Timestamp)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{
				"error": fmt.Sprintf("invalid timestamp: %v", d.Timestamp),
			})
			return
		}
		heartRates = append(heartRates, HeartRate{
			TrainID:   req.TrainID,
			BeltAddr:  d.BeltAddr,
			Time:      t.In(loc),
			Value:     d.Value,
			LastValue: d.LastValue,
		})
	}

	// The conflict target must match HeartRate's primary key (time, train_id).
	// No RETURNING clause: the inserted rows are not read back.
	result := db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "train_id"}, {Name: "time"}},
		DoNothing: true,
	}).CreateInBatches(heartRates, 1000)
	if result.Error != nil {
		log.Printf("Batch insert error: %v", result.Error)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "数据存储失败"})
		return
	}

	c.JSON(http.StatusCreated, gin.H{
		"message":  fmt.Sprintf("成功存储 %d 条心率数据", result.RowsAffected),
		"received": len(heartRates),
		"timezone": loc.String(),
	})
}
// getTrainingReport handles GET /api/trainings/:id/report. The route is
// currently commented out in setupRouter.
//
// For one training it returns the overall average and maximum heart rate,
// the estimated seconds spent at or above the danger threshold, and the same
// figures per belt address. A training with no samples yields zeros rather
// than an error.
//
// The "threshold" query parameter (integer, default 120) sets the danger
// threshold; a non-integer value returns 400. Database errors return 500.
func getTrainingReport(c *gin.Context) {
	// Danger time is estimated by assuming one sample every 5 seconds.
	const sampleIntervalSeconds = 5

	trainID := parseUint(c.Param("id"))

	threshold := 120
	if raw := c.Query("threshold"); raw != "" {
		// Exactly one operand must be filled. A second one means trailing
		// input such as "12.5" or "120abc".
		var trailing string
		if n, _ := fmt.Sscan(raw, &threshold, &trailing); n != 1 {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid threshold: %q", raw)})
			return
		}
	}

	report := TrainingReport{
		TrainID:   trainID,
		BeltStats: make(map[string]Stats),
	}

	// Each query starts from a fresh db.Model chain. Reusing one *gorm.DB
	// statement across queries lets conditions leak from one into the next.
	// COALESCE keeps the scan from failing on NULL when there are no rows.
	var dangerCount int64
	err := db.Model(&HeartRate{}).
		Select("COALESCE(AVG(value), 0), COALESCE(MAX(value), 0), COUNT(*) FILTER (WHERE value >= ?)", threshold).
		Where("train_id = ?", trainID).
		Row().
		Scan(&report.AvgHeartRate, &report.MaxHeartRate, &dangerCount)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	report.DangerSeconds = int(dangerCount) * sampleIntervalSeconds

	rows, err := db.Model(&HeartRate{}).
		Select("belt_addr, AVG(value) AS avg, MAX(value) AS max, COUNT(*) FILTER (WHERE value >= ?) AS danger", threshold).
		Where("train_id = ?", trainID).
		Group("belt_addr").
		Rows()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer rows.Close()

	for rows.Next() {
		var addr string
		var stat Stats
		if err := rows.Scan(&addr, &stat.Avg, &stat.Max, &stat.DangerCount); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		report.BeltStats[addr] = stat
	}
	if err := rows.Err(); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, report)
}
// getTimeSeriesData handles GET /api/trainings/:id/time-series. The route is
// currently commented out in setupRouter.
//
// It buckets the training's samples with PostgreSQL date_trunc and returns,
// per bucket and in time order, the average heart rate and the number of
// samples at or above the danger threshold.
//
// Query parameters:
//   - interval: bucket size, one of 1s/1m/1h/1d (or s/m/h/d, or
//     second/minute/hour/day); defaults to "1m". date_trunc only accepts unit
//     names, so the short forms are translated. Any other value returns 400.
//   - threshold: danger threshold as an integer, default 120. A non-integer
//     value returns 400.
func getTimeSeriesData(c *gin.Context) {
	trainID := parseUint(c.Param("id"))

	unit, ok := timeSeriesTruncUnit(c.DefaultQuery("interval", "1m"))
	if !ok {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported interval; use 1s, 1m, 1h or 1d"})
		return
	}

	threshold := 120
	if raw := c.Query("threshold"); raw != "" {
		// Exactly one operand must be filled. A second one means trailing
		// input such as "12.5" or "120abc".
		var trailing string
		if n, _ := fmt.Sscan(raw, &threshold, &trailing); n != 1 {
			c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid threshold: %q", raw)})
			return
		}
	}

	var points []TimePoint
	err := db.Raw(`
		SELECT
			date_trunc(?, time) AS time,
			AVG(value) AS avg_value,
			COUNT(*) FILTER (WHERE value >= ?) AS danger_count
		FROM heart_rates
		WHERE train_id = ?
		GROUP BY 1
		ORDER BY 1
	`, unit, threshold, trainID).Scan(&points).Error
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, points)
}

// timeSeriesTruncUnit maps an interval query value to a date_trunc unit name.
// The second result is false for unsupported intervals.
func timeSeriesTruncUnit(interval string) (string, bool) {
	switch interval {
	case "1s", "s", "second":
		return "second", true
	case "1m", "m", "minute":
		return "minute", true
	case "1h", "h", "hour":
		return "hour", true
	case "1d", "d", "day":
		return "day", true
	}
	return "", false
}
// parseUint reads a leading decimal number from s, using fmt.Sscanf's %d
// verb, and returns it as a uint. It returns 0 when s does not start with a
// valid unsigned decimal number. Characters after the digits are ignored, so
// "12abc" yields 12.
func parseUint(s string) uint {
	var value uint
	if _, err := fmt.Sscanf(s, "%d", &value); err != nil {
		return 0
	}
	return value
}
// main initializes the database (initDB exits the process on failure),
// builds the router and serves HTTP on port 8081. It never returns normally:
// once Run returns, its error is passed to log.Fatal.
func main() {
	initDB()
	engine := setupRouter()
	runErr := engine.Run(":8081")
	log.Fatal(runErr)
}