From 85ec3cba4a4fd1fd9ad0aa51a5479cd77d33538d Mon Sep 17 00:00:00 2001 From: laoboli <1293528695@qq.com> Date: Tue, 3 Feb 2026 15:43:42 +0800 Subject: [PATCH] feat: account. --- controllers/login.go | 117 +++++++++++++++++++++++++++++++++++++ controllers/step_train.go | 15 ++++- main.go | 1 + middleware/jwt_for_user.go | 42 +++++++++++++ models/step_train.go | 1 + models/user.go | 40 +++++++++++++ routes/routes.go | 7 ++- util/jwt.go | 56 ++++++++++++++++++ 8 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 controllers/login.go create mode 100644 middleware/jwt_for_user.go create mode 100644 models/user.go create mode 100644 util/jwt.go diff --git a/controllers/login.go b/controllers/login.go new file mode 100644 index 0000000..17df234 --- /dev/null +++ b/controllers/login.go @@ -0,0 +1,117 @@ +package controllers + +import ( + "hr_receiver/config" + "hr_receiver/models" + "hr_receiver/util" + "net/http" + + "github.com/gin-gonic/gin" +) + +type LoginRequest struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` +} + +type RegisterRequest struct { + Username string `json:"username" form:"username"` + Password string `json:"password" form:"password"` +} + +type AuthResponse struct { + Token string `json:"token"` + User models.User `json:"user"` +} + +// Register 用户注册 +func Register(c *gin.Context) { + var req RegisterRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 检查用户名是否已存在 + var existingUser models.User + if result := config.DB.Where("username = ?", req.Username).First(&existingUser); result.Error == nil { + c.JSON(http.StatusConflict, gin.H{"error": "Username already exists"}) + return + } + + // 创建新用户 + user := models.User{ + Username: req.Username, + Password: req.Password, // BeforeCreate钩子会自动加密 + } + + if result := config.DB.Create(&user); result.Error != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) + return + } + + // 生成Token + token, err := util.GenerateToken(user.ID, user.Username) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + return + } + + c.JSON(http.StatusCreated, AuthResponse{ + Token: token, + User: user, + }) +} + +// Login 用户登录 +func Login(c *gin.Context) { + var req LoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 查找用户 + var user models.User + result := config.DB.Where("username = ?", req.Username).First(&user) + + if result.Error != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid username or password"}) + return + } + + // 验证密码 + if !user.CheckPassword(req.Password) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid username or password"}) + return + } + + // 生成JWT Token + token, err := util.GenerateToken(user.ID, user.Username) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"}) + return + } + + c.JSON(http.StatusOK, AuthResponse{ + Token: token, + User: user, + }) +} + +// GetProfile 获取用户信息(需要认证) +func GetProfile(c *gin.Context) { + userID, exists := c.Get("userID") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + var user models.User + if result := config.DB.First(&user, userID); result.Error != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + c.JSON(http.StatusOK, user) +} diff --git a/controllers/step_train.go b/controllers/step_train.go index a7b9471..411abea 100644 --- a/controllers/step_train.go +++ b/controllers/step_train.go @@ -33,6 +33,12 @@ func (tc *StepTrainingController) CreateTrainingRecord(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + username, exists := c.Get("username") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "无法获取用户信息,请重新登录"}) + return + } + record.Username = username.(string) // 使用事务保存数据[4](@ref) err := tc.DB.Transaction(func(tx *gorm.DB) error { @@ -123,6 +129,11 @@ func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) { PageNum int `form:"pageNum,default=1"` // 页码,默认第一页 PageSize int `form:"pageSize,default=10"` // 每页数量,默认10条 } + username, exists := c.Get("username") + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{"error": "无法获取用户信息,请重新登录"}) + return + } var params PaginationParams if err := c.ShouldBindQuery(¶ms); err != nil { @@ -147,13 +158,13 @@ func (tc *StepTrainingController) GetTrainingRecords(c *gin.Context) { ) // 获取总记录数 - if err := tc.DB.Model(&models.StepTrainRecord{}).Count(&totalRows).Error; err != nil { + if err := tc.DB.Model(&models.StepTrainRecord{}).Where("username = ?", username).Count(&totalRows).Error; err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "获取记录总数失败"}) return } // 查询分页数据(按开始时间倒序排列) - result := tc.DB. + result := tc.DB.Where("username = ?", username). Order("start_time DESC"). // 按开始时间倒序 Offset(offset). Limit(params.PageSize). diff --git a/main.go b/main.go index 6ee59fa..d857126 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,7 @@ func main() { &models.StepHeartRate{}, &models.StepStrideFreq{}, &models.RegressionResult{}, + &models.User{}, ) // 启动服务 diff --git a/middleware/jwt_for_user.go b/middleware/jwt_for_user.go new file mode 100644 index 0000000..0b06f19 --- /dev/null +++ b/middleware/jwt_for_user.go @@ -0,0 +1,42 @@ +package middleware + +import ( + "hr_receiver/util" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +func JWTAuth() gin.HandlerFunc { + return func(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header required"}) + c.Abort() + return + } + + // Bearer Token格式 + parts := strings.SplitN(authHeader, " ", 2) + if !(len(parts) == 2 && parts[0] == "Bearer") { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header format must be Bearer {token}"}) + c.Abort() + return + } + + // 解析Token + claims, err := util.ParseToken(parts[1]) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"}) + c.Abort() + return + } + + // 将用户信息存入上下文 + c.Set("userID", claims.UserID) + c.Set("username", claims.Username) + + c.Next() + } +} diff --git a/models/step_train.go b/models/step_train.go index 1236b97..f4ad318 100644 --- a/models/step_train.go +++ b/models/step_train.go @@ -25,6 +25,7 @@ type StepHeartRate struct { // 对应Flutter的TrainRecord结构 type StepTrainRecord struct { gorm.Model + Username string `gorm:"size:50" json:"username"` // 对应Dart的tid字段 TrainId uint `gorm:"uniqueIndex" json:"tid"` // 对应Dart的tid字段 StartTime int64 `gorm:"type:bigint" json:"time"` // 开始时间戳 EndTime int64 `gorm:"type:bigint" json:"endTime"` // 结束时间戳[3](@ref) diff --git a/models/user.go b/models/user.go new file mode 100644 index 0000000..d9c2c94 --- /dev/null +++ b/models/user.go @@ -0,0 +1,40 @@ +package models + +import ( + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +type User struct { + ID uint `gorm:"primaryKey" json:"id"` + Username string `gorm:"uniqueIndex;not null" json:"username"` + Email string `gorm:"uniqueIndex;" json:"email"` + Phone string `gorm:"uniqueIndex;" json:"phone"` + Password string `gorm:"not null" json:"-"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +// HashPassword 密码加密 +func (u *User) HashPassword(password string) error { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14) + if err != nil { + return err + } + u.Password = string(bytes) + return nil +} + +// CheckPassword 验证密码 +func (u *User) CheckPassword(password string) bool { + err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) + return err == nil +} + +// BeforeCreate 创建前钩子 +func (u *User) BeforeCreate(tx *gorm.DB) (err error) { + if u.Password != "" { + return u.HashPassword(u.Password) + } + return nil +} diff --git a/routes/routes.go b/routes/routes.go index b8e8aaa..5b652ae 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -23,7 +23,7 @@ func SetupRouter() *gin.Engine { records.POST("/analysis-by-ai", trainingController.AnalyzeByAI) // 可扩展其他路由:GET, PUT, DELETE等 } - steps := v1.Group("/step") //.Use(middleware.AuthMiddleware()) + steps := v1.Group("/step").Use(middleware.JWTAuth()) { steps.POST("", stepTrainController.CreateTrainingRecord) steps.GET("train-records", stepTrainController.GetTrainingRecords) @@ -31,6 +31,11 @@ func SetupRouter() *gin.Engine { steps.GET("train-rank/:trainId", stepTrainController.GetTrainingRank) // 可扩展其他路由:GET, PUT, DELETE等 } + public := v1.Group("") + { + public.POST("/register", controllers.Register) + public.POST("/login", controllers.Login) + } auth := v1.Group("/auth") { auth.GET("/token", func(c *gin.Context) { diff --git a/util/jwt.go b/util/jwt.go new file mode 100644 index 0000000..fd8571f --- /dev/null +++ b/util/jwt.go @@ -0,0 +1,56 @@ +package util + +import ( + "errors" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ApiSecret = "your-super-secret-key" // 预共享密钥 +type Claims struct { + UserID uint `json:"user_id"` + Username string `json:"username"` + jwt.RegisteredClaims +} + +// GenerateToken 生成JWT Token +func GenerateToken(userID uint, username string) (string, error) { + expirationTime := time.Now().Add(24 * 30 * time.Hour) // Token有效期24小时 + //expirationTime := time.Now().Add(1 * time.Second) // Token有效期24小时 + + claims := &Claims{ + UserID: userID, + Username: username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "your-app-name", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte(ApiSecret)) + + return tokenString, err +} + +// ParseToken 解析JWT Token +func ParseToken(tokenStr string) (*Claims, error) { + claims := &Claims{} + + token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ApiSecret), nil + }) + + if err != nil { + return nil, err + } + + if !token.Valid { + return nil, errors.New("invalid token") + } + + return claims, nil +}