feat: mqtt receive.
This commit is contained in:
@@ -0,0 +1,300 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"hr_receiver/config"
|
||||
"hr_receiver/models"
|
||||
whgw_hrpb "hr_receiver/proto"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 2048
|
||||
defaultWorkers = 4
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
db *gorm.DB
|
||||
cfg config.MQTTConfig
|
||||
client mqtt.Client
|
||||
writeCh chan interface{}
|
||||
}
|
||||
|
||||
func Start(db *gorm.DB, cfg config.MQTTConfig) error {
|
||||
if !cfg.Enabled {
|
||||
log.Println("mqtt listener disabled")
|
||||
return nil
|
||||
}
|
||||
if err := validateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener := &Listener{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
writeCh: make(chan interface{}, defaultQueueSize),
|
||||
}
|
||||
|
||||
for i := 0; i < defaultWorkers; i++ {
|
||||
go listener.writeWorker()
|
||||
}
|
||||
|
||||
if err := listener.connect(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfig(cfg config.MQTTConfig) error {
|
||||
if cfg.Host == "" {
|
||||
return fmt.Errorf("missing config: mqtt.host")
|
||||
}
|
||||
if cfg.Port == 0 {
|
||||
return fmt.Errorf("missing config: mqtt.port")
|
||||
}
|
||||
if cfg.Region == "" {
|
||||
return fmt.Errorf("missing config: mqtt.region")
|
||||
}
|
||||
if cfg.ClientIDPrefix == "" {
|
||||
return fmt.Errorf("missing config: mqtt.client_id_prefix")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) connect() error {
|
||||
opts := mqtt.NewClientOptions()
|
||||
scheme := "tcp"
|
||||
if l.cfg.UseTLS {
|
||||
scheme = "ssl"
|
||||
opts.SetTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12})
|
||||
}
|
||||
broker := fmt.Sprintf("%s://%s:%d", scheme, l.cfg.Host, l.cfg.Port)
|
||||
opts.AddBroker(broker)
|
||||
opts.SetClientID(fmt.Sprintf("%s-%d", l.cfg.ClientIDPrefix, time.Now().UnixNano()))
|
||||
opts.SetUsername(l.cfg.Username)
|
||||
opts.SetPassword(l.cfg.Password)
|
||||
opts.SetKeepAlive(60 * time.Second)
|
||||
opts.SetAutoReconnect(true)
|
||||
opts.SetConnectRetry(true)
|
||||
opts.SetConnectRetryInterval(5 * time.Second)
|
||||
opts.SetOnConnectHandler(func(client mqtt.Client) {
|
||||
if err := l.subscribe(client); err != nil {
|
||||
log.Printf("mqtt subscribe failed: %v", err)
|
||||
return
|
||||
}
|
||||
log.Printf("mqtt connected to %s", broker)
|
||||
})
|
||||
opts.SetConnectionLostHandler(func(client mqtt.Client, err error) {
|
||||
log.Printf("mqtt connection lost: %v", err)
|
||||
})
|
||||
opts.SetDefaultPublishHandler(l.handleMessage)
|
||||
|
||||
l.client = mqtt.NewClient(opts)
|
||||
token := l.client.Connect()
|
||||
if !token.WaitTimeout(15 * time.Second) {
|
||||
return fmt.Errorf("mqtt connect timeout")
|
||||
}
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
func (l *Listener) subscribe(client mqtt.Client) error {
|
||||
topics := []string{
|
||||
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/hr", l.cfg.Region),
|
||||
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/step", l.cfg.Region),
|
||||
fmt.Sprintf("/whgw/v2/region/%s/gateway/+/status", l.cfg.Region),
|
||||
}
|
||||
for _, topic := range topics {
|
||||
token := client.Subscribe(topic, byte(l.cfg.QoS), l.handleMessage)
|
||||
if !token.WaitTimeout(10 * time.Second) {
|
||||
return fmt.Errorf("mqtt subscribe timeout for topic %s", topic)
|
||||
}
|
||||
if err := token.Error(); err != nil {
|
||||
return fmt.Errorf("mqtt subscribe topic %s: %w", topic, err)
|
||||
}
|
||||
log.Printf("mqtt subscribed: %s", topic)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) handleMessage(_ mqtt.Client, msg mqtt.Message) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("mqtt message handling panic, topic=%s err=%v", msg.Topic(), r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(msg.Payload()) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var packet whgw_hrpb.GatewaySlaveOutCloudMasterInMsg
|
||||
if err := proto.Unmarshal(msg.Payload(), &packet); err != nil {
|
||||
log.Printf("mqtt payload parse failed, topic=%s err=%v", msg.Topic(), err)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
switch payload := packet.Choice.(type) {
|
||||
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfHrMeasurement:
|
||||
record := buildHeartRateRecord(payload.NtfHrMeasurement, msg.Topic(), now)
|
||||
l.enqueue(&record)
|
||||
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfStepCountMeasurement:
|
||||
record := buildStepCountRecord(payload.NtfStepCountMeasurement, msg.Topic(), now)
|
||||
l.enqueue(&record)
|
||||
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfGatewayStatus:
|
||||
record := buildGatewayStatusRecord(payload.NtfGatewayStatus, msg.Topic(), now)
|
||||
l.enqueue(&record)
|
||||
default:
|
||||
log.Printf("mqtt payload ignored, unsupported type on topic=%s", msg.Topic())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) enqueue(record interface{}) {
|
||||
select {
|
||||
case l.writeCh <- record:
|
||||
default:
|
||||
log.Printf("mqtt write queue full, dropping record of type %T", record)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) writeWorker() {
|
||||
for record := range l.writeCh {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("mqtt record persist panic, type=%T err=%v", record, r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := l.db.Clauses(clause.OnConflict{DoNothing: true}).Create(record).Error; err != nil {
|
||||
log.Printf("mqtt record persist failed, type=%T err=%v", record, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func buildHeartRateRecord(measurement *whgw_hrpb.HrMeasurement, topic string, now int64) models.MqttHeartRateRecord {
|
||||
regionID := measurement.GetGatewayInfo().GetRegionId()
|
||||
if regionID == 0 {
|
||||
regionID = parseRegionFromTopic(topic)
|
||||
}
|
||||
gatewayMAC := formatMAC(measurement.GetGatewayInfo().GetGatewayMac())
|
||||
packet := measurement.GetHrPacket()
|
||||
rssi, snr := parsePacketStatus(measurement.GetPacketStatus())
|
||||
beltAddr := fmt.Sprintf("%d-%d", regionID, packet.GetId())
|
||||
|
||||
return models.MqttHeartRateRecord{
|
||||
Identifier: fmt.Sprintf("hr:%d:%s:%d:%d", regionID, gatewayMAC, packet.GetId(), packet.GetPacketNum()),
|
||||
Topic: topic,
|
||||
RegionID: regionID,
|
||||
GatewayMAC: gatewayMAC,
|
||||
BandID: packet.GetId(),
|
||||
BeltAddr: beltAddr,
|
||||
PacketNum: packet.GetPacketNum(),
|
||||
HeartRate: int(packet.GetHr()),
|
||||
HrConfidence: int(packet.GetStatus().GetHrConfidence().Number()),
|
||||
IsActive: packet.GetStatus().GetIsActive(),
|
||||
IsOnSkin: packet.GetStatus().GetIsOnSkin(),
|
||||
Battery: packet.GetStatus().GetBattery(),
|
||||
SignalRSSINeg: rssi,
|
||||
SNR: snr,
|
||||
HubBusID: measurement.GetHubInfo().GetBusId(),
|
||||
HubSubDevID: measurement.GetHubInfo().GetSubDevId(),
|
||||
ReceivedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func buildStepCountRecord(measurement *whgw_hrpb.StepCountMeasurement, topic string, now int64) models.MqttStepCountRecord {
|
||||
regionID := measurement.GetGatewayInfo().GetRegionId()
|
||||
if regionID == 0 {
|
||||
regionID = parseRegionFromTopic(topic)
|
||||
}
|
||||
gatewayMAC := formatMAC(measurement.GetGatewayInfo().GetGatewayMac())
|
||||
packet := measurement.GetStepCountPacket()
|
||||
rssi, snr := parsePacketStatus(measurement.GetPacketStatus())
|
||||
beltAddr := fmt.Sprintf("%d-%d", regionID, packet.GetId())
|
||||
|
||||
return models.MqttStepCountRecord{
|
||||
Identifier: fmt.Sprintf("step:%d:%s:%d:%d", regionID, gatewayMAC, packet.GetId(), packet.GetPacketNum()),
|
||||
Topic: topic,
|
||||
RegionID: regionID,
|
||||
GatewayMAC: gatewayMAC,
|
||||
BandID: packet.GetId(),
|
||||
BeltAddr: beltAddr,
|
||||
PacketNum: packet.GetPacketNum(),
|
||||
StepCount: packet.GetStepCount(),
|
||||
SignalRSSINeg: rssi,
|
||||
SNR: snr,
|
||||
HubBusID: measurement.GetHubInfo().GetBusId(),
|
||||
HubSubDevID: measurement.GetHubInfo().GetSubDevId(),
|
||||
ReceivedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func buildGatewayStatusRecord(status *whgw_hrpb.GatewayStatus, topic string, now int64) models.MqttGatewayStatusRecord {
|
||||
regionID := status.GetInfo().GetRegionId()
|
||||
if regionID == 0 {
|
||||
regionID = parseRegionFromTopic(topic)
|
||||
}
|
||||
gatewayMAC := formatMAC(status.GetInfo().GetGatewayMac())
|
||||
|
||||
return models.MqttGatewayStatusRecord{
|
||||
Identifier: fmt.Sprintf("gateway:%d:%s:%d:%d:%d", regionID, gatewayMAC, status.GetStat().GetBootCount(), status.GetStat().GetUptimeMs(), status.GetStat().GetRxCount()),
|
||||
Topic: topic,
|
||||
RegionID: regionID,
|
||||
GatewayMAC: gatewayMAC,
|
||||
BootCount: status.GetStat().GetBootCount(),
|
||||
UptimeMs: status.GetStat().GetUptimeMs(),
|
||||
DurationMsSinceLastPacket: status.GetStat().GetDurationMsSinceLastPacket(),
|
||||
RxCount: status.GetStat().GetRxCount(),
|
||||
BatteryVoltageMV: status.GetStat().GetBatteryInfo().GetVoltageMv(),
|
||||
BatterySOCPercentage: status.GetStat().GetBatteryInfo().GetSocPercentage(),
|
||||
ChargingRatePercentage: status.GetStat().GetBatteryInfo().GetChargingRatePercentage(),
|
||||
ReceivedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func parsePacketStatus(status *whgw_hrpb.IPacketStatus) (float64, float64) {
|
||||
if status == nil {
|
||||
return 0, 0
|
||||
}
|
||||
if parsed := status.GetParsed(); parsed != nil {
|
||||
return float64(parsed.GetSignalRssiNeg()), float64(parsed.GetSnrPkt())
|
||||
}
|
||||
if raw := status.GetRaw(); raw != nil {
|
||||
return -float64(raw.GetSignalRssiX2Neg()) / 2, float64(raw.GetSnrPktX4()) / 4
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func formatMAC(data []byte) string {
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, 0, len(data))
|
||||
for _, b := range data {
|
||||
parts = append(parts, fmt.Sprintf("%02x", b))
|
||||
}
|
||||
return strings.Join(parts, ":")
|
||||
}
|
||||
|
||||
func parseRegionFromTopic(topic string) uint32 {
|
||||
parts := strings.Split(topic, "/")
|
||||
for i := 0; i < len(parts)-1; i++ {
|
||||
if parts[i] == "region" {
|
||||
var region uint32
|
||||
if _, err := fmt.Sscanf(parts[i+1], "%d", ®ion); err == nil {
|
||||
return region
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
Reference in New Issue
Block a user