From 3d53f00ec9dcb555621509479a8ba42973e5e448 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 10 Jun 2025 18:41:55 +0800 Subject: [PATCH] refactor: Update message handling to include timestamps and improve polling intervals in heart rate monitoring --- main.py | 65 +++++++++++++++++++++++++++------------------------------ 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/main.py b/main.py index 99d87b1..be5a2ce 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ from threading import Thread from time import sleep from pydantic import BaseModel, computed_field from datetime import datetime, timedelta +import struct import awkward as ak from awkward import Array as AwkwardArray, Record as AwkwardRecord from app.model import AlgoReport @@ -35,6 +36,7 @@ from collections import deque from dataclasses import dataclass from aiomqtt import Client as MqttClient, Message as MqttMessage from app.proto.hr_packet import HrPacket, HrOnlyPacket, HrPpgPacket, HrConfidence +import betterproto MQTT_BROKER: Final[str] = "weihua-iot.cn" @@ -55,14 +57,17 @@ class DeviceHistory(TypedDict): # https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825 class AppState(TypedDict): worker_thread: Thread - message_queue: MemoryObjectReceiveStream[bytes] - mqtt_message_queue: MemoryObjectReceiveStream[MqttMessage] + message_queue: MemoryObjectReceiveStream[Tuple[datetime, bytes]] + mqtt_message_queue: MemoryObjectReceiveStream[Tuple[datetime, MqttMessage]] task_group: TaskGroup device_histories: Dict[Union[int, str], DeviceHistory] # device_id -> DeviceHistory - refresh_inst: Instant -DEFAULT_BUSY_POLLING_INTERVAL: Final[float] = 0.001 +BUSY_POLLING_INTERVAL_S: Final[float] = 0.001 +BATCH_MESSAGE_INTERVAL_S: Final[float] = 0.1 +NORMAL_REFRESH_INTERVAL_S: Final[float] = 0.5 +QUEUE_BUFFER_SIZE: Final[int] = 32 + UDP_SERVER_HOST: Final[str] = "localhost" UDP_SERVER_PORT: Final[int] = 50_000 MAX_LENGTH = 600 @@ -166,8 +171,12 @@ def get_device_name(device_id: Union[int, str]) -> str: @st.cache_resource def resource(params: Any = None): set_ev = anyio.Event() - tx, rx = create_memory_object_stream[bytes]() - mqtt_tx, mqtt_rx = create_memory_object_stream[MqttMessage]() + tx, rx = create_memory_object_stream[Tuple[datetime, bytes]]( + max_buffer_size=QUEUE_BUFFER_SIZE + ) + mqtt_tx, mqtt_rx = create_memory_object_stream[Tuple[datetime, MqttMessage]]( + max_buffer_size=QUEUE_BUFFER_SIZE + ) tg: Optional[TaskGroup] = None async def udp_task(): @@ -178,7 +187,8 @@ def resource(params: Any = None): "UDP server listening on {}:{}", UDP_SERVER_HOST, UDP_SERVER_PORT ) async for packet, _ in udp: - await tx.send(packet) + timestamp = datetime.now() + await tx.send((timestamp, packet)) async def mqtt_task(): async with MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) as client: @@ -190,7 +200,8 @@ def resource(params: Any = None): TOPIC, ) async for message in client.messages: - await mqtt_tx.send(message) + timestamp = datetime.now() + await mqtt_tx.send((timestamp, message)) async def combined_task(): nonlocal set_ev, tg @@ -205,7 +216,7 @@ def resource(params: Any = None): tr.start() while not set_ev.is_set(): - sleep(DEFAULT_BUSY_POLLING_INTERVAL) + sleep(BUSY_POLLING_INTERVAL_S) logger.info("UDP and MQTT tasks initialized in single thread") @@ -215,7 +226,6 @@ def resource(params: Any = None): "mqtt_message_queue": mqtt_rx, "task_group": unwrap(tg), "device_histories": {}, - "refresh_inst": Instant(), } logger.info("Resource created") return state @@ -243,7 +253,7 @@ def main(): device_histories.clear() # https://docs.streamlit.io/develop/api-reference/layout - st.title("MAX-BAND Visualizer") + st.title("HR Visualizer") with st.container(border=True): c1, c2 = st.columns(2) with c1: @@ -270,15 +280,13 @@ def main(): else: selected_devices = [] - # Process available messages (no infinite loop) # Process UDP messages (treat as device_id = 0) - message_processed = False try: - while True: # Process all available UDP messages - message = state["message_queue"].receive_nowait() - hr_value = parse_ble_hr_measurement(message) + while True: + timestamp, packet = state["message_queue"].receive_nowait() + hr_value = parse_ble_hr_measurement(packet) if hr_value is not None: - now = datetime.now() + now = timestamp if UDP_DEVICE_ID not in device_histories: device_histories[UDP_DEVICE_ID] = create_device_history() @@ -288,17 +296,15 @@ def main(): dev_hist["hr_data"].append(float(hr_value)) logger.debug("UDP Device: HR={}", hr_value) - message_processed = True except anyio.WouldBlock: pass # Process MQTT messages try: - while True: # Process all available MQTT messages - mqtt_message = state["mqtt_message_queue"].receive_nowait() + while True: + timestamp, mqtt_message = state["mqtt_message_queue"].receive_nowait() if mqtt_message.payload: try: - # Ensure payload is bytes payload_bytes = mqtt_message.payload if isinstance(payload_bytes, str): payload_bytes = payload_bytes.encode("utf-8") @@ -308,9 +314,8 @@ def main(): hr_packet = HrPacket() hr_packet.parse(payload_bytes) - now = datetime.now() + now = timestamp - # Extract HR data based on packet type device_id = None hr_value = None @@ -324,7 +329,6 @@ def main(): hr_value = packet.hr if device_id is not None and hr_value is not None: - # Ensure device history exists if device_id not in device_histories: device_histories[device_id] = create_device_history() @@ -333,18 +337,11 @@ def main(): dev_hist["hr_data"].append(float(hr_value)) logger.debug("Device {}: HR={}", device_id, hr_value) - message_processed = True - - except Exception as e: - logger.error("Failed to parse MQTT protobuf message: {}", e) + except (ValueError, TypeError, IndexError, UnicodeDecodeError) as e: + logger.error("MQTT protobuf message parsing: {}", e) except anyio.WouldBlock: pass - # Auto-refresh the page if new data was processed - if message_processed: - sleep(0.1) # Small delay to batch multiple messages - st.rerun() - # Update visualization - HR Graphs if device_histories: st.subheader("Heart Rate Data") @@ -396,7 +393,7 @@ def main(): else: st.info("No devices connected yet. Waiting for data...") - sleep(1) + sleep(NORMAL_REFRESH_INTERVAL_S) st.rerun()