Files
hr_data_analyzer/controllers/login.go
T
2026-05-04 16:20:46 +08:00

170 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package controllers
import (
"hr_receiver/config"
"hr_receiver/models"
"hr_receiver/util"
"net/http"
"github.com/gin-gonic/gin"
)
// LoginRequest is the JSON body accepted by Login. Both fields carry
// `binding:"required"`, so gin's binder rejects the request (Login answers
// 400) when either is missing or empty.
type LoginRequest struct {
	Username string `json:"username" binding:"required"`
	Password string `json:"password" binding:"required"` // plaintext; compared via models.User.CheckPassword
}
// RegisterRequest is the body accepted by Register. It also carries form tags,
// but Register binds it from JSON only. No field has a `binding` tag, so the
// binder enforces nothing; required-field checks must happen in the handler.
type RegisterRequest struct {
	Username string `json:"username" form:"username"`
	Password string `json:"password" form:"password"` // plaintext as received from the client
	// Optional contact details; nil when the field is omitted.
	Email *string `json:"email" form:"email"`
	Phone *string `json:"phone" form:"phone"`
	// Role and FlavorType are copied into the new user unchanged by Register.
	Role       models.UserRole       `json:"role" form:"role"`
	FlavorType models.UserFlavorType `json:"flavorType" form:"flavorType"`
	// RegionIDs is converted into region bindings; zero and duplicate IDs are
	// dropped by buildUserRegionBindings.
	RegionIDs []uint32 `json:"regionIds" form:"regionIds"`
}
// AuthResponse is the success payload returned by Register and Login: a token
// produced by util.GenerateToken plus the user record.
//
// NOTE(review): User is serialized as-is, so every field models.User exposes
// through its JSON tags is sent to the client. Confirm the model hides the
// password hash (e.g. `json:"-"`).
type AuthResponse struct {
	Token string      `json:"token"`
	User  models.User `json:"user"`
}
// Register creates a new user account from a JSON body and returns a JWT
// together with the created user.
//
// Responses: 400 when the body fails to bind or the username/password is
// empty, 409 when the username is already taken, 500 on database or token
// errors, and 201 with an AuthResponse on success.
//
// NOTE(review): Role and FlavorType are taken verbatim from the request, so a
// caller can pick its own role at registration. Confirm this endpoint is not
// public, or restrict which roles may be self-assigned.
//
// @Summary 用户注册
// @Description 注册新用户返回JWT Token
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body SwagRegisterRequest true "注册信息"
// @Success 201 {object} SwagAPIResponse "注册成功"
// @Failure 400 {object} SwagAPIResponse "请求参数错误"
// @Failure 409 {object} SwagAPIResponse "用户名已存在"
// @Router /register [post]
func Register(c *gin.Context) {
	var req RegisterRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// RegisterRequest has no `binding:"required"` tags, so an empty JSON object
	// binds successfully. Reject missing credentials explicitly instead of
	// creating a user with an empty username or password.
	if req.Username == "" || req.Password == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Username and password are required"})
		return
	}

	// Check whether the username is already taken. Count keeps real database
	// failures (500) separate from an existing username (409); First would
	// report both as an error.
	// NOTE(review): a concurrent registration can still slip in between this
	// check and Create. If username has a unique index (TODO confirm), Create
	// then fails and the client gets 500 rather than 409.
	var existing int64
	if err := config.DB.Model(&models.User{}).Where("username = ?", req.Username).Count(&existing).Error; err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check username"})
		return
	}
	if existing > 0 {
		c.JSON(http.StatusConflict, gin.H{"error": "Username already exists"})
		return
	}

	// Build the new user from the request.
	user := models.User{
		Username:   req.Username,
		Email:      req.Email,
		Phone:      req.Phone,
		Password:   req.Password, // plaintext here; the model's BeforeCreate hook is expected to hash it
		Role:       req.Role,
		FlavorType: req.FlavorType,
		Regions:    buildUserRegionBindings(req.RegionIDs),
	}
	if result := config.DB.Create(&user); result.Error != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
		return
	}

	// Issue a token for the freshly created account.
	token, err := util.GenerateToken(&user)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
		return
	}
	c.JSON(http.StatusCreated, AuthResponse{
		Token: token,
		User:  user,
	})
}
// Login authenticates a user by username and password. On success it returns
// a JWT plus the user record with its Regions preloaded.
//
// Responses: 400 when the body fails binding, 401 when the lookup fails for
// any reason or the password does not match (same message for both), 403 when
// the account is inactive, 500 when token generation fails, and 200 with an
// AuthResponse otherwise.
//
// @Summary 用户登录
// @Description 用户名密码登录返回JWT Token和用户信息
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body SwagLoginRequest true "登录信息"
// @Success 200 {object} SwagAPIResponse "登录成功"
// @Failure 400 {object} SwagAPIResponse "请求参数错误"
// @Failure 401 {object} SwagAPIResponse "用户名或密码错误"
// @Failure 403 {object} SwagAPIResponse "用户已禁用"
// @Router /login [post]
func Login(c *gin.Context) {
	var creds LoginRequest
	if bindErr := c.ShouldBindJSON(&creds); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	var account models.User
	lookup := config.DB.Preload("Regions").Where("username = ?", creds.Username).First(&account)

	// An unknown user and a wrong password get the same response so callers
	// cannot tell which usernames exist. The short-circuit skips the password
	// check when the lookup failed.
	if lookup.Error != nil || !account.CheckPassword(creds.Password) {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid username or password"})
		return
	}

	// The disabled state is checked only after the password, so it is revealed
	// only to someone who already holds valid credentials.
	if !account.IsActive {
		c.JSON(http.StatusForbidden, gin.H{"error": "User is disabled"})
		return
	}

	token, tokenErr := util.GenerateToken(&account)
	if tokenErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
		return
	}
	c.JSON(http.StatusOK, AuthResponse{Token: token, User: account})
}
// GetProfile returns the current user's record with Regions preloaded. It
// requires authentication: it reads the user ID stored under the "userID"
// context key, which an upstream auth middleware is presumably expected to set
// (TODO confirm).
//
// Responses: 401 when the key is absent, 404 when the lookup fails for any
// reason, and 200 with the user otherwise.
func GetProfile(c *gin.Context) {
	id, ok := c.Get("userID")
	if !ok {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
		return
	}

	var profile models.User
	// id is handed to GORM as an inline primary-key condition.
	lookupErr := config.DB.Preload("Regions").First(&profile, id).Error
	if lookupErr != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
		return
	}

	c.JSON(http.StatusOK, profile)
}
// buildUserRegionBindings turns a list of region IDs into UserRegionBinding
// values. It skips zero IDs and repeated IDs and keeps the order of first
// occurrence.
//
// It returns nil for an empty input. If every ID is filtered out, the result
// is an empty, non-nil slice.
func buildUserRegionBindings(regionIDs []uint32) []models.UserRegionBinding {
	if len(regionIDs) == 0 {
		return nil
	}

	bindings := make([]models.UserRegionBinding, 0, len(regionIDs))
	added := make(map[uint32]bool, len(regionIDs))
	for i := range regionIDs {
		id := regionIDs[i]
		// Zero is treated as "no region"; duplicates would create repeated bindings.
		if id == 0 || added[id] {
			continue
		}
		added[id] = true
		bindings = append(bindings, models.UserRegionBinding{RegionID: id})
	}
	return bindings
}