feat: mqtt receive.

This commit is contained in:
2026-04-28 15:29:16 +08:00
parent 51871c352a
commit 2464617599
9 changed files with 2820 additions and 2 deletions
+300
View File
@@ -0,0 +1,300 @@
package mqtt
import (
"crypto/tls"
"fmt"
"hr_receiver/config"
"hr_receiver/models"
whgw_hrpb "hr_receiver/proto"
"log"
"strings"
"time"
mqtt "github.com/eclipse/paho.mqtt.golang"
"google.golang.org/protobuf/proto"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
const (
defaultQueueSize = 2048
defaultWorkers = 4
)
type Listener struct {
db *gorm.DB
cfg config.MQTTConfig
client mqtt.Client
writeCh chan interface{}
}
func Start(db *gorm.DB, cfg config.MQTTConfig) error {
if !cfg.Enabled {
log.Println("mqtt listener disabled")
return nil
}
if err := validateConfig(cfg); err != nil {
return err
}
listener := &Listener{
db: db,
cfg: cfg,
writeCh: make(chan interface{}, defaultQueueSize),
}
for i := 0; i < defaultWorkers; i++ {
go listener.writeWorker()
}
if err := listener.connect(); err != nil {
return err
}
return nil
}
func validateConfig(cfg config.MQTTConfig) error {
if cfg.Host == "" {
return fmt.Errorf("missing config: mqtt.host")
}
if cfg.Port == 0 {
return fmt.Errorf("missing config: mqtt.port")
}
if cfg.Region == "" {
return fmt.Errorf("missing config: mqtt.region")
}
if cfg.ClientIDPrefix == "" {
return fmt.Errorf("missing config: mqtt.client_id_prefix")
}
return nil
}
func (l *Listener) connect() error {
opts := mqtt.NewClientOptions()
scheme := "tcp"
if l.cfg.UseTLS {
scheme = "ssl"
opts.SetTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12})
}
broker := fmt.Sprintf("%s://%s:%d", scheme, l.cfg.Host, l.cfg.Port)
opts.AddBroker(broker)
opts.SetClientID(fmt.Sprintf("%s-%d", l.cfg.ClientIDPrefix, time.Now().UnixNano()))
opts.SetUsername(l.cfg.Username)
opts.SetPassword(l.cfg.Password)
opts.SetKeepAlive(60 * time.Second)
opts.SetAutoReconnect(true)
opts.SetConnectRetry(true)
opts.SetConnectRetryInterval(5 * time.Second)
opts.SetOnConnectHandler(func(client mqtt.Client) {
if err := l.subscribe(client); err != nil {
log.Printf("mqtt subscribe failed: %v", err)
return
}
log.Printf("mqtt connected to %s", broker)
})
opts.SetConnectionLostHandler(func(client mqtt.Client, err error) {
log.Printf("mqtt connection lost: %v", err)
})
opts.SetDefaultPublishHandler(l.handleMessage)
l.client = mqtt.NewClient(opts)
token := l.client.Connect()
if !token.WaitTimeout(15 * time.Second) {
return fmt.Errorf("mqtt connect timeout")
}
return token.Error()
}
func (l *Listener) subscribe(client mqtt.Client) error {
topics := []string{
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/hr", l.cfg.Region),
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/step", l.cfg.Region),
fmt.Sprintf("/whgw/v2/region/%s/gateway/+/status", l.cfg.Region),
}
for _, topic := range topics {
token := client.Subscribe(topic, byte(l.cfg.QoS), l.handleMessage)
if !token.WaitTimeout(10 * time.Second) {
return fmt.Errorf("mqtt subscribe timeout for topic %s", topic)
}
if err := token.Error(); err != nil {
return fmt.Errorf("mqtt subscribe topic %s: %w", topic, err)
}
log.Printf("mqtt subscribed: %s", topic)
}
return nil
}
func (l *Listener) handleMessage(_ mqtt.Client, msg mqtt.Message) {
defer func() {
if r := recover(); r != nil {
log.Printf("mqtt message handling panic, topic=%s err=%v", msg.Topic(), r)
}
}()
if len(msg.Payload()) == 0 {
return
}
var packet whgw_hrpb.GatewaySlaveOutCloudMasterInMsg
if err := proto.Unmarshal(msg.Payload(), &packet); err != nil {
log.Printf("mqtt payload parse failed, topic=%s err=%v", msg.Topic(), err)
return
}
now := time.Now().UnixMilli()
switch payload := packet.Choice.(type) {
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfHrMeasurement:
record := buildHeartRateRecord(payload.NtfHrMeasurement, msg.Topic(), now)
l.enqueue(&record)
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfStepCountMeasurement:
record := buildStepCountRecord(payload.NtfStepCountMeasurement, msg.Topic(), now)
l.enqueue(&record)
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfGatewayStatus:
record := buildGatewayStatusRecord(payload.NtfGatewayStatus, msg.Topic(), now)
l.enqueue(&record)
default:
log.Printf("mqtt payload ignored, unsupported type on topic=%s", msg.Topic())
}
}
func (l *Listener) enqueue(record interface{}) {
select {
case l.writeCh <- record:
default:
log.Printf("mqtt write queue full, dropping record of type %T", record)
}
}
func (l *Listener) writeWorker() {
for record := range l.writeCh {
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("mqtt record persist panic, type=%T err=%v", record, r)
}
}()
if err := l.db.Clauses(clause.OnConflict{DoNothing: true}).Create(record).Error; err != nil {
log.Printf("mqtt record persist failed, type=%T err=%v", record, err)
}
}()
}
}
func buildHeartRateRecord(measurement *whgw_hrpb.HrMeasurement, topic string, now int64) models.MqttHeartRateRecord {
regionID := measurement.GetGatewayInfo().GetRegionId()
if regionID == 0 {
regionID = parseRegionFromTopic(topic)
}
gatewayMAC := formatMAC(measurement.GetGatewayInfo().GetGatewayMac())
packet := measurement.GetHrPacket()
rssi, snr := parsePacketStatus(measurement.GetPacketStatus())
beltAddr := fmt.Sprintf("%d-%d", regionID, packet.GetId())
return models.MqttHeartRateRecord{
Identifier: fmt.Sprintf("hr:%d:%s:%d:%d", regionID, gatewayMAC, packet.GetId(), packet.GetPacketNum()),
Topic: topic,
RegionID: regionID,
GatewayMAC: gatewayMAC,
BandID: packet.GetId(),
BeltAddr: beltAddr,
PacketNum: packet.GetPacketNum(),
HeartRate: int(packet.GetHr()),
HrConfidence: int(packet.GetStatus().GetHrConfidence().Number()),
IsActive: packet.GetStatus().GetIsActive(),
IsOnSkin: packet.GetStatus().GetIsOnSkin(),
Battery: packet.GetStatus().GetBattery(),
SignalRSSINeg: rssi,
SNR: snr,
HubBusID: measurement.GetHubInfo().GetBusId(),
HubSubDevID: measurement.GetHubInfo().GetSubDevId(),
ReceivedAt: now,
}
}
func buildStepCountRecord(measurement *whgw_hrpb.StepCountMeasurement, topic string, now int64) models.MqttStepCountRecord {
regionID := measurement.GetGatewayInfo().GetRegionId()
if regionID == 0 {
regionID = parseRegionFromTopic(topic)
}
gatewayMAC := formatMAC(measurement.GetGatewayInfo().GetGatewayMac())
packet := measurement.GetStepCountPacket()
rssi, snr := parsePacketStatus(measurement.GetPacketStatus())
beltAddr := fmt.Sprintf("%d-%d", regionID, packet.GetId())
return models.MqttStepCountRecord{
Identifier: fmt.Sprintf("step:%d:%s:%d:%d", regionID, gatewayMAC, packet.GetId(), packet.GetPacketNum()),
Topic: topic,
RegionID: regionID,
GatewayMAC: gatewayMAC,
BandID: packet.GetId(),
BeltAddr: beltAddr,
PacketNum: packet.GetPacketNum(),
StepCount: packet.GetStepCount(),
SignalRSSINeg: rssi,
SNR: snr,
HubBusID: measurement.GetHubInfo().GetBusId(),
HubSubDevID: measurement.GetHubInfo().GetSubDevId(),
ReceivedAt: now,
}
}
func buildGatewayStatusRecord(status *whgw_hrpb.GatewayStatus, topic string, now int64) models.MqttGatewayStatusRecord {
regionID := status.GetInfo().GetRegionId()
if regionID == 0 {
regionID = parseRegionFromTopic(topic)
}
gatewayMAC := formatMAC(status.GetInfo().GetGatewayMac())
return models.MqttGatewayStatusRecord{
Identifier: fmt.Sprintf("gateway:%d:%s:%d:%d:%d", regionID, gatewayMAC, status.GetStat().GetBootCount(), status.GetStat().GetUptimeMs(), status.GetStat().GetRxCount()),
Topic: topic,
RegionID: regionID,
GatewayMAC: gatewayMAC,
BootCount: status.GetStat().GetBootCount(),
UptimeMs: status.GetStat().GetUptimeMs(),
DurationMsSinceLastPacket: status.GetStat().GetDurationMsSinceLastPacket(),
RxCount: status.GetStat().GetRxCount(),
BatteryVoltageMV: status.GetStat().GetBatteryInfo().GetVoltageMv(),
BatterySOCPercentage: status.GetStat().GetBatteryInfo().GetSocPercentage(),
ChargingRatePercentage: status.GetStat().GetBatteryInfo().GetChargingRatePercentage(),
ReceivedAt: now,
}
}
func parsePacketStatus(status *whgw_hrpb.IPacketStatus) (float64, float64) {
if status == nil {
return 0, 0
}
if parsed := status.GetParsed(); parsed != nil {
return float64(parsed.GetSignalRssiNeg()), float64(parsed.GetSnrPkt())
}
if raw := status.GetRaw(); raw != nil {
return -float64(raw.GetSignalRssiX2Neg()) / 2, float64(raw.GetSnrPktX4()) / 4
}
return 0, 0
}
func formatMAC(data []byte) string {
if len(data) == 0 {
return ""
}
parts := make([]string, 0, len(data))
for _, b := range data {
parts = append(parts, fmt.Sprintf("%02x", b))
}
return strings.Join(parts, ":")
}
func parseRegionFromTopic(topic string) uint32 {
parts := strings.Split(topic, "/")
for i := 0; i < len(parts)-1; i++ {
if parts[i] == "region" {
var region uint32
if _, err := fmt.Sscanf(parts[i+1], "%d", &region); err == nil {
return region
}
}
}
return 0
}