feat: start record.

This commit is contained in:
2026-04-28 18:52:03 +08:00
parent 2464617599
commit 9a95130488
7 changed files with 251 additions and 15 deletions
+166 -6
View File
@@ -2,6 +2,7 @@ package mqtt
import (
"crypto/tls"
"encoding/json"
"fmt"
"hr_receiver/config"
"hr_receiver/models"
@@ -14,6 +15,7 @@ import (
"google.golang.org/protobuf/proto"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
const (
@@ -28,6 +30,16 @@ type Listener struct {
writeCh chan interface{}
}
type trainingSessionPayload struct {
Type string `json:"type"`
EventType string `json:"eventType"`
TestID string `json:"testId"`
RegionID string `json:"regionId"`
Flavor string `json:"flavor"`
AppName string `json:"appName"`
Timestamp string `json:"timestamp"`
}
func Start(db *gorm.DB, cfg config.MQTTConfig) error {
if !cfg.Enabled {
log.Println("mqtt listener disabled")
@@ -106,10 +118,16 @@ func (l *Listener) connect() error {
}
func (l *Listener) subscribe(client mqtt.Client) error {
topics := []string{
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/hr", l.cfg.Region),
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/step", l.cfg.Region),
fmt.Sprintf("/whgw/v2/region/%s/gateway/+/status", l.cfg.Region),
var topics []string
if l.cfg.EnableMeasurementSubscriptions {
topics = append(topics,
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/hr", l.cfg.Region),
fmt.Sprintf("/whgw/v2/region/%s/measurement/band/+/step", l.cfg.Region),
fmt.Sprintf("/whgw/v2/region/%s/gateway/+/status", l.cfg.Region),
)
}
if l.cfg.EnableTrainingEventSubscription {
topics = append(topics, "/whgw/v2/region/test/+/+")
}
for _, topic := range topics {
token := client.Subscribe(topic, byte(l.cfg.QoS), l.handleMessage)
@@ -134,14 +152,23 @@ func (l *Listener) handleMessage(_ mqtt.Client, msg mqtt.Message) {
if len(msg.Payload()) == 0 {
return
}
now := time.Now().UnixMilli()
var packet whgw_hrpb.GatewaySlaveOutCloudMasterInMsg
if isTrainingEventTopic(msg.Topic()) {
record, ok := buildTrainingSessionRecord(msg.Topic(), msg.Payload(), now)
if !ok {
return
}
l.enqueue(record)
return
}
if err := proto.Unmarshal(msg.Payload(), &packet); err != nil {
log.Printf("mqtt payload parse failed, topic=%s err=%v", msg.Topic(), err)
return
}
now := time.Now().UnixMilli()
switch payload := packet.Choice.(type) {
case *whgw_hrpb.GatewaySlaveOutCloudMasterInMsg_NtfHrMeasurement:
record := buildHeartRateRecord(payload.NtfHrMeasurement, msg.Topic(), now)
@@ -174,13 +201,48 @@ func (l *Listener) writeWorker() {
}
}()
if err := l.db.Clauses(clause.OnConflict{DoNothing: true}).Create(record).Error; err != nil {
if err := l.persistRecord(record); err != nil {
log.Printf("mqtt record persist failed, type=%T err=%v", record, err)
}
}()
}
}
func (l *Listener) persistRecord(record interface{}) error {
switch r := record.(type) {
case *models.MqttTrainingSessionRecord:
return l.persistTrainingSession(r)
default:
return l.db.Clauses(clause.OnConflict{DoNothing: true}).Create(record).Error
}
}
func (l *Listener) persistTrainingSession(record *models.MqttTrainingSessionRecord) error {
assignments := map[string]interface{}{
"topic": record.Topic,
"event_type": record.EventType,
"region_id": record.RegionID,
"flavor_type": record.FlavorType,
"raw_flavor": record.RawFlavor,
"app_name": record.AppName,
"published_at": record.PublishedAt,
"received_at": record.ReceivedAt,
"raw_payload": record.RawPayload,
"updated_at": time.Now(),
}
if record.StartedAt != nil {
assignments["started_at"] = *record.StartedAt
}
if record.EndedAt != nil {
assignments["ended_at"] = *record.EndedAt
}
return l.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "identifier"}},
DoUpdates: clause.Assignments(assignments),
}).Create(record).Error
}
func buildHeartRateRecord(measurement *whgw_hrpb.HrMeasurement, topic string, now int64) models.MqttHeartRateRecord {
regionID := measurement.GetGatewayInfo().GetRegionId()
if regionID == 0 {
@@ -298,3 +360,101 @@ func parseRegionFromTopic(topic string) uint32 {
}
return 0
}
func isTrainingEventTopic(topic string) bool {
return strings.HasPrefix(topic, "/whgw/v2/region/test/")
}
func buildTrainingSessionRecord(topic string, payload []byte, now int64) (*models.MqttTrainingSessionRecord, bool) {
var event trainingSessionPayload
if err := json.Unmarshal(payload, &event); err != nil {
log.Printf("mqtt training event parse failed, topic=%s err=%v", topic, err)
return nil, false
}
if event.Type != "mqtt_test" {
log.Printf("mqtt training event ignored, unsupported type topic=%s type=%s", topic, event.Type)
return nil, false
}
flavorType := normalizeFlavor(event.Flavor)
if flavorType != "heartrate" {
log.Printf("mqtt training event ignored, unsupported flavor topic=%s flavor=%s", topic, event.Flavor)
return nil, false
}
regionID := parseUint32(event.RegionID)
if regionID == 0 {
regionID = parseRegionFromTrainingTopic(topic)
}
publishedAt := parseRFC3339Milli(event.Timestamp)
if publishedAt == 0 {
publishedAt = now
}
record := &models.MqttTrainingSessionRecord{
Identifier: buildTrainingSessionIdentifier(flavorType, regionID, event.TestID),
Topic: topic,
TestID: event.TestID,
EventType: event.EventType,
RegionID: regionID,
FlavorType: flavorType,
RawFlavor: event.Flavor,
AppName: event.AppName,
PublishedAt: publishedAt,
ReceivedAt: now,
RawPayload: string(payload),
}
switch event.EventType {
case "start_test":
record.StartedAt = &publishedAt
case "stop_test":
record.EndedAt = &publishedAt
default:
log.Printf("mqtt training event ignored, unsupported event topic=%s event=%s", topic, event.EventType)
return nil, false
}
return record, true
}
func buildTrainingSessionIdentifier(flavorType string, regionID uint32, testID string) string {
return schema.NamingStrategy{}.IndexName(
"mqtt_training_session",
fmt.Sprintf("%s_%d_%s", flavorType, regionID, testID),
)
}
func normalizeFlavor(flavor string) string {
switch strings.ToLower(strings.TrimSpace(flavor)) {
case "hr", "heartrate":
return "heartrate"
default:
return strings.ToLower(strings.TrimSpace(flavor))
}
}
func parseUint32(value string) uint32 {
var result uint32
if _, err := fmt.Sscanf(strings.TrimSpace(value), "%d", &result); err == nil {
return result
}
return 0
}
func parseRegionFromTrainingTopic(topic string) uint32 {
parts := strings.Split(topic, "/")
if len(parts) >= 6 {
return parseUint32(parts[5])
}
return 0
}
func parseRFC3339Milli(value string) int64 {
if strings.TrimSpace(value) == "" {
return 0
}
t, err := time.Parse(time.RFC3339, value)
if err != nil {
return 0
}
return t.UnixMilli()
}