274 lines
7.6 KiB
Go
274 lines
7.6 KiB
Go
package mqtt
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"hr_receiver/config"
|
|
"hr_receiver/models"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
mqtt "github.com/eclipse/paho.mqtt.golang"
|
|
"github.com/gorilla/websocket"
|
|
"google.golang.org/protobuf/proto"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
whgw_hrpb "hr_receiver/proto"
|
|
)
|
|
|
|
// DebugStatus is the JSON snapshot of the debug service returned by
// (*DebugService).Status.
type DebugStatus struct {
	// Active is set by a successful Start and cleared by Stop or by the
	// connection-lost handler.
	Active bool `json:"active"`
	// ClientConnected reports whether an MQTT client exists and reports
	// itself as connected.
	ClientConnected bool `json:"clientConnected"`
	// PersistToDatabase reports whether decoded records are also written to
	// the database.
	PersistToDatabase bool `json:"persistToDatabase"`
	// Region is the configured region used to build the subscription topics.
	Region string `json:"region"`
	// SubscriberCount is the number of registered websocket subscribers.
	SubscriberCount int `json:"subscriberCount"`
}
|
|
|
|
// DebugEvent is the JSON message broadcast to websocket subscribers for each
// decoded MQTT message. handleMessage sets exactly one of HeartRate,
// StepCount or GatewayStatus, matching Kind; the nil ones are omitted from
// the JSON.
type DebugEvent struct {
	// CardKey identifies the source device: "<regionID>-<bandID>" for
	// heart-rate and step-count events, "<regionID>-<gatewayMAC>" for gateway
	// status events.
	CardKey string `json:"cardKey"`
	// Kind is one of "heart_rate", "step_count" or "gateway_status".
	Kind string `json:"kind"`
	// RegionID is copied from the decoded record.
	RegionID uint32 `json:"regionId"`
	// ReceivedAt is the local receive time in Unix milliseconds.
	ReceivedAt int64 `json:"receivedAt"`
	// Topic is the MQTT topic the message arrived on.
	Topic string `json:"topic"`
	// Payload for the matching Kind; nil (and omitted) otherwise.
	HeartRate     *models.MqttHeartRateRecord     `json:"heartRate,omitempty"`
	StepCount     *models.MqttStepCountRecord     `json:"stepCount,omitempty"`
	GatewayStatus *models.MqttGatewayStatusRecord `json:"gatewayStatus,omitempty"`
}
|
|
|
|
// DebugService runs an on-demand MQTT debug session. It subscribes to the
// configured region's measurement and gateway topics, optionally persists the
// decoded records, and fans them out to websocket subscribers as DebugEvents.
//
// Construct it with InitDebugService: the zero value has a nil subscribers
// map, so AddSubscriber would panic on it.
type DebugService struct {
	// cfg holds the MQTT connection settings and region used by
	// connectLocked and subscribe.
	cfg config.MQTTConfig
	// client is the current MQTT client, or nil when no session is running.
	// Guarded by mu.
	client mqtt.Client
	// db is the GORM handle maybePersist writes records to.
	db *gorm.DB
	// mu guards client, persistToDatabase, subscribers and active.
	mu sync.RWMutex
	// persistToDatabase controls whether maybePersist writes records.
	// Guarded by mu.
	persistToDatabase bool
	// subscribers is the set of websocket connections broadcast writes to.
	// Guarded by mu.
	subscribers map[*websocket.Conn]struct{}
	// active is true between a successful Start and the next Stop or lost
	// connection. Guarded by mu.
	active bool
}
|
|
|
|
var globalDebugService *DebugService
|
|
|
|
func InitDebugService(db *gorm.DB, cfg config.MQTTConfig) {
|
|
globalDebugService = &DebugService{
|
|
cfg: cfg,
|
|
db: db,
|
|
subscribers: make(map[*websocket.Conn]struct{}),
|
|
}
|
|
}
|
|
|
|
func GetDebugService() *DebugService {
|
|
return globalDebugService
|
|
}
|
|
|
|
func (s *DebugService) Status() DebugStatus {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return DebugStatus{
|
|
Active: s.active,
|
|
ClientConnected: s.client != nil && s.client.IsConnected(),
|
|
PersistToDatabase: s.persistToDatabase,
|
|
Region: s.cfg.Region,
|
|
SubscriberCount: len(s.subscribers),
|
|
}
|
|
}
|
|
|
|
func (s *DebugService) Start(persistToDatabase bool) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if s.active && s.client != nil && s.client.IsConnected() {
|
|
s.persistToDatabase = persistToDatabase
|
|
return nil
|
|
}
|
|
if err := validateConfig(s.cfg); err != nil {
|
|
return err
|
|
}
|
|
|
|
client, err := s.connectLocked(persistToDatabase)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.client = client
|
|
s.persistToDatabase = persistToDatabase
|
|
s.active = true
|
|
return nil
|
|
}
|
|
|
|
func (s *DebugService) Stop() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.client != nil && s.client.IsConnected() {
|
|
s.client.Disconnect(250)
|
|
}
|
|
s.client = nil
|
|
s.active = false
|
|
s.persistToDatabase = false
|
|
}
|
|
|
|
func (s *DebugService) AddSubscriber(conn *websocket.Conn) {
|
|
s.mu.Lock()
|
|
s.subscribers[conn] = struct{}{}
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *DebugService) RemoveSubscriber(conn *websocket.Conn) {
|
|
s.mu.Lock()
|
|
delete(s.subscribers, conn)
|
|
s.mu.Unlock()
|
|
_ = conn.Close()
|
|
}
|
|
|
|
func (s *DebugService) connectLocked(persistToDatabase bool) (mqtt.Client, error) {
|
|
opts := mqtt.NewClientOptions()
|
|
scheme := "tcp"
|
|
if s.cfg.UseTLS {
|
|
scheme = "ssl"
|
|
opts.SetTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12})
|
|
}
|
|
broker := fmt.Sprintf("%s://%s:%d", scheme, s.cfg.Host, s.cfg.Port)
|
|
opts.AddBroker(broker)
|
|
opts.SetClientID(fmt.Sprintf("%s-debug-%d", s.cfg.ClientIDPrefix, time.Now().UnixNano()))
|
|
opts.SetUsername(s.cfg.Username)
|
|
opts.SetPassword(s.cfg.Password)
|
|
opts.SetKeepAlive(60 * time.Second)
|
|
opts.SetAutoReconnect(false)
|
|
opts.SetConnectRetry(false)
|
|
opts.SetDefaultPublishHandler(s.handleMessage)
|
|
opts.SetOnConnectHandler(func(client mqtt.Client) {
|
|
if err := s.subscribe(client); err != nil {
|
|
log.Printf("mqtt debug subscribe failed: %v", err)
|
|
return
|
|
}
|
|
log.Printf("mqtt debug connected to %s persist=%v", broker, persistToDatabase)
|
|
})
|
|
opts.SetConnectionLostHandler(func(client mqtt.Client, err error) {
|
|
log.Printf("mqtt debug connection lost: %v", err)
|
|
s.mu.Lock()
|
|
if s.client == client {
|
|
s.client = nil
|
|
s.active = false
|
|
}
|
|
s.mu.Unlock()
|
|
})
|
|
|
|
client := mqtt.NewClient(opts)
|
|
token := client.Connect()
|
|
if !token.WaitTimeout(15 * time.Second) {
|
|
return nil, fmt.Errorf("mqtt debug connect timeout")
|
|
}
|
|
if err := token.Error(); err != nil {
|
|
return nil, err
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (s *DebugService) subscribe(client mqtt.Client) error {
|
|
topics := []string{
|
|
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/hr", s.cfg.Region),
|
|
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/step", s.cfg.Region),
|
|
fmt.Sprintf("/whgw/v2/region/%s/gateway/+/status", s.cfg.Region),
|
|
}
|
|
for _, topic := range topics {
|
|
token := client.Subscribe(topic, byte(s.cfg.QoS), s.handleMessage)
|
|
if !token.WaitTimeout(10 * time.Second) {
|
|
return fmt.Errorf("mqtt debug subscribe timeout for topic %s", topic)
|
|
}
|
|
if err := token.Error(); err != nil {
|
|
return fmt.Errorf("mqtt debug subscribe topic %s: %w", topic, err)
|
|
}
|
|
log.Printf("mqtt debug subscribed: %s", topic)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *DebugService) handleMessage(_ mqtt.Client, msg mqtt.Message) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("mqtt debug handle panic topic=%s err=%v", msg.Topic(), r)
|
|
}
|
|
}()
|
|
if len(msg.Payload()) == 0 {
|
|
return
|
|
}
|
|
|
|
now := time.Now().UnixMilli()
|
|
var packet whgw_hrpb.GatewaySlaveOutCloudMasterInMsg
|
|
if err := proto.Unmarshal(msg.Payload(), &packet); err != nil {
|
|
log.Printf("mqtt debug payload parse failed topic=%s err=%v", msg.Topic(), err)
|
|
return
|
|
}
|
|
|
|
switch payload := packet.Choice.(type) {
|
|
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfHrMeasurement:
|
|
record := buildHeartRateRecord(payload.NtfHrMeasurement, msg.Topic(), now)
|
|
s.maybePersist(&record)
|
|
s.broadcast(DebugEvent{
|
|
CardKey: fmt.Sprintf("%d-%d", record.RegionID, record.BandID),
|
|
HeartRate: &record,
|
|
Kind: "heart_rate",
|
|
ReceivedAt: now,
|
|
RegionID: record.RegionID,
|
|
Topic: msg.Topic(),
|
|
})
|
|
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfStepCountMeasurement:
|
|
record := buildStepCountRecord(payload.NtfStepCountMeasurement, msg.Topic(), now)
|
|
s.maybePersist(&record)
|
|
s.broadcast(DebugEvent{
|
|
CardKey: fmt.Sprintf("%d-%d", record.RegionID, record.BandID),
|
|
Kind: "step_count",
|
|
ReceivedAt: now,
|
|
RegionID: record.RegionID,
|
|
StepCount: &record,
|
|
Topic: msg.Topic(),
|
|
})
|
|
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfGatewayStatus:
|
|
record := buildGatewayStatusRecord(payload.NtfGatewayStatus, msg.Topic(), now)
|
|
s.maybePersist(&record)
|
|
s.broadcast(DebugEvent{
|
|
CardKey: fmt.Sprintf("%d-%s", record.RegionID, record.GatewayMAC),
|
|
GatewayStatus: &record,
|
|
Kind: "gateway_status",
|
|
ReceivedAt: now,
|
|
RegionID: record.RegionID,
|
|
Topic: msg.Topic(),
|
|
})
|
|
default:
|
|
log.Printf("mqtt debug payload ignored topic=%s", msg.Topic())
|
|
}
|
|
}
|
|
|
|
func (s *DebugService) maybePersist(record interface{}) {
|
|
s.mu.RLock()
|
|
enabled := s.persistToDatabase
|
|
s.mu.RUnlock()
|
|
if !enabled {
|
|
return
|
|
}
|
|
if err := s.db.Clauses(clause.OnConflict{DoNothing: true}).Create(record).Error; err != nil {
|
|
log.Printf("mqtt debug persist failed type=%T err=%v", record, err)
|
|
}
|
|
}
|
|
|
|
func (s *DebugService) broadcast(event DebugEvent) {
|
|
payload, err := json.Marshal(event)
|
|
if err != nil {
|
|
log.Printf("mqtt debug marshal failed err=%v", err)
|
|
return
|
|
}
|
|
|
|
s.mu.RLock()
|
|
conns := make([]*websocket.Conn, 0, len(s.subscribers))
|
|
for conn := range s.subscribers {
|
|
conns = append(conns, conn)
|
|
}
|
|
s.mu.RUnlock()
|
|
|
|
for _, conn := range conns {
|
|
if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
|
log.Printf("mqtt debug websocket send failed err=%v", err)
|
|
s.RemoveSubscriber(conn)
|
|
}
|
|
}
|
|
}
|