refactor: Update message handling to include timestamps and improve polling intervals in heart rate monitoring

This commit is contained in:
2025-06-10 18:41:55 +08:00
parent 3a15bd655e
commit 3d53f00ec9

65
main.py
View File

@ -27,6 +27,7 @@ from threading import Thread
from time import sleep from time import sleep
from pydantic import BaseModel, computed_field from pydantic import BaseModel, computed_field
from datetime import datetime, timedelta from datetime import datetime, timedelta
import struct
import awkward as ak import awkward as ak
from awkward import Array as AwkwardArray, Record as AwkwardRecord from awkward import Array as AwkwardArray, Record as AwkwardRecord
from app.model import AlgoReport from app.model import AlgoReport
@ -35,6 +36,7 @@ from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from aiomqtt import Client as MqttClient, Message as MqttMessage from aiomqtt import Client as MqttClient, Message as MqttMessage
from app.proto.hr_packet import HrPacket, HrOnlyPacket, HrPpgPacket, HrConfidence from app.proto.hr_packet import HrPacket, HrOnlyPacket, HrPpgPacket, HrConfidence
import betterproto
MQTT_BROKER: Final[str] = "weihua-iot.cn" MQTT_BROKER: Final[str] = "weihua-iot.cn"
@ -55,14 +57,17 @@ class DeviceHistory(TypedDict):
# https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825 # https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825
class AppState(TypedDict): class AppState(TypedDict):
worker_thread: Thread worker_thread: Thread
message_queue: MemoryObjectReceiveStream[bytes] message_queue: MemoryObjectReceiveStream[Tuple[datetime, bytes]]
mqtt_message_queue: MemoryObjectReceiveStream[MqttMessage] mqtt_message_queue: MemoryObjectReceiveStream[Tuple[datetime, MqttMessage]]
task_group: TaskGroup task_group: TaskGroup
device_histories: Dict[Union[int, str], DeviceHistory] # device_id -> DeviceHistory device_histories: Dict[Union[int, str], DeviceHistory] # device_id -> DeviceHistory
refresh_inst: Instant
DEFAULT_BUSY_POLLING_INTERVAL: Final[float] = 0.001 BUSY_POLLING_INTERVAL_S: Final[float] = 0.001
BATCH_MESSAGE_INTERVAL_S: Final[float] = 0.1
NORMAL_REFRESH_INTERVAL_S: Final[float] = 0.5
QUEUE_BUFFER_SIZE: Final[int] = 32
UDP_SERVER_HOST: Final[str] = "localhost" UDP_SERVER_HOST: Final[str] = "localhost"
UDP_SERVER_PORT: Final[int] = 50_000 UDP_SERVER_PORT: Final[int] = 50_000
MAX_LENGTH = 600 MAX_LENGTH = 600
@ -166,8 +171,12 @@ def get_device_name(device_id: Union[int, str]) -> str:
@st.cache_resource @st.cache_resource
def resource(params: Any = None): def resource(params: Any = None):
set_ev = anyio.Event() set_ev = anyio.Event()
tx, rx = create_memory_object_stream[bytes]() tx, rx = create_memory_object_stream[Tuple[datetime, bytes]](
mqtt_tx, mqtt_rx = create_memory_object_stream[MqttMessage]() max_buffer_size=QUEUE_BUFFER_SIZE
)
mqtt_tx, mqtt_rx = create_memory_object_stream[Tuple[datetime, MqttMessage]](
max_buffer_size=QUEUE_BUFFER_SIZE
)
tg: Optional[TaskGroup] = None tg: Optional[TaskGroup] = None
async def udp_task(): async def udp_task():
@ -178,7 +187,8 @@ def resource(params: Any = None):
"UDP server listening on {}:{}", UDP_SERVER_HOST, UDP_SERVER_PORT "UDP server listening on {}:{}", UDP_SERVER_HOST, UDP_SERVER_PORT
) )
async for packet, _ in udp: async for packet, _ in udp:
await tx.send(packet) timestamp = datetime.now()
await tx.send((timestamp, packet))
async def mqtt_task(): async def mqtt_task():
async with MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) as client: async with MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) as client:
@ -190,7 +200,8 @@ def resource(params: Any = None):
TOPIC, TOPIC,
) )
async for message in client.messages: async for message in client.messages:
await mqtt_tx.send(message) timestamp = datetime.now()
await mqtt_tx.send((timestamp, message))
async def combined_task(): async def combined_task():
nonlocal set_ev, tg nonlocal set_ev, tg
@ -205,7 +216,7 @@ def resource(params: Any = None):
tr.start() tr.start()
while not set_ev.is_set(): while not set_ev.is_set():
sleep(DEFAULT_BUSY_POLLING_INTERVAL) sleep(BUSY_POLLING_INTERVAL_S)
logger.info("UDP and MQTT tasks initialized in single thread") logger.info("UDP and MQTT tasks initialized in single thread")
@ -215,7 +226,6 @@ def resource(params: Any = None):
"mqtt_message_queue": mqtt_rx, "mqtt_message_queue": mqtt_rx,
"task_group": unwrap(tg), "task_group": unwrap(tg),
"device_histories": {}, "device_histories": {},
"refresh_inst": Instant(),
} }
logger.info("Resource created") logger.info("Resource created")
return state return state
@ -243,7 +253,7 @@ def main():
device_histories.clear() device_histories.clear()
# https://docs.streamlit.io/develop/api-reference/layout # https://docs.streamlit.io/develop/api-reference/layout
st.title("MAX-BAND Visualizer") st.title("HR Visualizer")
with st.container(border=True): with st.container(border=True):
c1, c2 = st.columns(2) c1, c2 = st.columns(2)
with c1: with c1:
@ -270,15 +280,13 @@ def main():
else: else:
selected_devices = [] selected_devices = []
# Process available messages (no infinite loop)
# Process UDP messages (treat as device_id = 0) # Process UDP messages (treat as device_id = 0)
message_processed = False
try: try:
while True: # Process all available UDP messages while True:
message = state["message_queue"].receive_nowait() timestamp, packet = state["message_queue"].receive_nowait()
hr_value = parse_ble_hr_measurement(message) hr_value = parse_ble_hr_measurement(packet)
if hr_value is not None: if hr_value is not None:
now = datetime.now() now = timestamp
if UDP_DEVICE_ID not in device_histories: if UDP_DEVICE_ID not in device_histories:
device_histories[UDP_DEVICE_ID] = create_device_history() device_histories[UDP_DEVICE_ID] = create_device_history()
@ -288,17 +296,15 @@ def main():
dev_hist["hr_data"].append(float(hr_value)) dev_hist["hr_data"].append(float(hr_value))
logger.debug("UDP Device: HR={}", hr_value) logger.debug("UDP Device: HR={}", hr_value)
message_processed = True
except anyio.WouldBlock: except anyio.WouldBlock:
pass pass
# Process MQTT messages # Process MQTT messages
try: try:
while True: # Process all available MQTT messages while True:
mqtt_message = state["mqtt_message_queue"].receive_nowait() timestamp, mqtt_message = state["mqtt_message_queue"].receive_nowait()
if mqtt_message.payload: if mqtt_message.payload:
try: try:
# Ensure payload is bytes
payload_bytes = mqtt_message.payload payload_bytes = mqtt_message.payload
if isinstance(payload_bytes, str): if isinstance(payload_bytes, str):
payload_bytes = payload_bytes.encode("utf-8") payload_bytes = payload_bytes.encode("utf-8")
@ -308,9 +314,8 @@ def main():
hr_packet = HrPacket() hr_packet = HrPacket()
hr_packet.parse(payload_bytes) hr_packet.parse(payload_bytes)
now = datetime.now() now = timestamp
# Extract HR data based on packet type
device_id = None device_id = None
hr_value = None hr_value = None
@ -324,7 +329,6 @@ def main():
hr_value = packet.hr hr_value = packet.hr
if device_id is not None and hr_value is not None: if device_id is not None and hr_value is not None:
# Ensure device history exists
if device_id not in device_histories: if device_id not in device_histories:
device_histories[device_id] = create_device_history() device_histories[device_id] = create_device_history()
@ -333,18 +337,11 @@ def main():
dev_hist["hr_data"].append(float(hr_value)) dev_hist["hr_data"].append(float(hr_value))
logger.debug("Device {}: HR={}", device_id, hr_value) logger.debug("Device {}: HR={}", device_id, hr_value)
message_processed = True except (ValueError, TypeError, IndexError, UnicodeDecodeError) as e:
logger.error("MQTT protobuf message parsing: {}", e)
except Exception as e:
logger.error("Failed to parse MQTT protobuf message: {}", e)
except anyio.WouldBlock: except anyio.WouldBlock:
pass pass
# Auto-refresh the page if new data was processed
if message_processed:
sleep(0.1) # Small delay to batch multiple messages
st.rerun()
# Update visualization - HR Graphs # Update visualization - HR Graphs
if device_histories: if device_histories:
st.subheader("Heart Rate Data") st.subheader("Heart Rate Data")
@ -396,7 +393,7 @@ def main():
else: else:
st.info("No devices connected yet. Waiting for data...") st.info("No devices connected yet. Waiting for data...")
sleep(1) sleep(NORMAL_REFRESH_INTERVAL_S)
st.rerun() st.rerun()