diff --git a/main.py b/main.py index 656008c..257dc7a 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,8 @@ from typing import ( TypedDict, Any, cast, + Dict, + Union, ) from loguru import logger @@ -32,35 +34,35 @@ from app.utils import Instant from collections import deque from dataclasses import dataclass from aiomqtt import Client as MqttClient, Message as MqttMessage +from app.proto.hr_packet import HrPacket, HrOnlyPacket, HrPpgPacket, HrConfidence -SELECT_BAND_ID = 17 - - -MQTT_BROKER: Final[str] = "192.168.2.189" +MQTT_BROKER: Final[str] = "weihua-iot.cn" MQTT_BROKER_PORT: Final[int] = 1883 MAX_LENGTH = 600 -TOPIC: Final[str] = "GwData" -NDArray = np.ndarray +TOPIC: Final[str] = "/hr/region/1/band/#" +UDP_DEVICE_ID: Final[int] = 0xFF +NDArray = np.ndarray T = TypeVar("T") -class AppHistory(TypedDict): +class DeviceHistory(TypedDict): timescape: deque[datetime] hr_data: deque[float] - hr_conf: deque[float] # in % # https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825 class AppState(TypedDict): worker_thread: Thread message_queue: MemoryObjectReceiveStream[bytes] + mqtt_message_queue: MemoryObjectReceiveStream[MqttMessage] task_group: TaskGroup - history: AppHistory + device_histories: Dict[Union[int, str], DeviceHistory] # device_id -> DeviceHistory refresh_inst: Instant +DEFAULT_BUSY_POLLING_INTERVAL: Final[float] = 0.001 UDP_SERVER_HOST: Final[str] = "localhost" UDP_SERVER_PORT: Final[int] = 50_000 MAX_LENGTH = 600 @@ -75,28 +77,6 @@ def unwrap(value: Optional[T]) -> T: return value -# /** -# * @brief Structure of the Heart Rate Measurement characteristic -# * -# * @see https://www.bluetooth.com/specifications/gss/ -# * @see section 3.116 Heart Rate Measurement of the document: GATT Specification Supplement. -# */ -# struct ble_hr_measurement_flag_t { -# // LSB first - -# /* -# * 0: uint8_t -# * 1: uint16_t -# */ -# bool heart_rate_value_format : 1; -# bool sensor_contact_detected : 1; -# bool sensor_contact_supported : 1; -# bool energy_expended_present : 1; -# bool rr_interval_present : 1; -# uint8_t reserved : 3; -# }; - - def parse_ble_hr_measurement(data: bytes) -> Optional[int]: """ Parse BLE Heart Rate Measurement characteristic data according to Bluetooth specification. @@ -106,6 +86,29 @@ def parse_ble_hr_measurement(data: bytes) -> Optional[int]: Returns: Heart rate value in BPM, or None if parsing fails + Note: + ```cpp + /** + * @brief Structure of the Heart Rate Measurement characteristic + * + * @see https://www.bluetooth.com/specifications/gss/ + * @see section 3.116 Heart Rate Measurement of the document: GATT Specification Supplement. + */ + struct ble_hr_measurement_flag_t { + // LSB first + + /* + * 0: uint8_t + * 1: uint16_t + */ + bool heart_rate_value_format : 1; + bool sensor_contact_detected : 1; + bool sensor_contact_supported : 1; + bool energy_expended_present : 1; + bool rr_interval_present : 1; + uint8_t reserved : 3; + }; + ``` """ if len(data) < 2: return None @@ -131,38 +134,87 @@ def parse_ble_hr_measurement(data: bytes) -> Optional[int]: return None +def hr_confidence_to_percentage(confidence: HrConfidence) -> float: + """Convert HrConfidence enum to percentage value""" + if confidence == HrConfidence.ZERO: + return 0 # mid-point of [0,25) + elif confidence == HrConfidence.LOW: + return 37.5 # mid-point of [25,50) + elif confidence == HrConfidence.MEDIUM: + return 62.5 # mid-point of [50,75) + elif confidence == HrConfidence.HIGH: + return 100 # mid-point of (75,100] + else: + raise ValueError(f"Invalid HrConfidence: {confidence}") + + +def create_device_history() -> DeviceHistory: + """Create a new device history structure""" + return { + "timescape": deque(maxlen=MAX_LENGTH), + "hr_data": deque(maxlen=MAX_LENGTH), + } + + +def get_device_name(device_id: Union[int, str]) -> str: + """Get display name for device""" + if device_id == UDP_DEVICE_ID: + return "UDP Device" + return f"Device {device_id}" + + @st.cache_resource def resource(params: Any = None): set_ev = anyio.Event() tx, rx = create_memory_object_stream[bytes]() + mqtt_tx, mqtt_rx = create_memory_object_stream[MqttMessage]() tg: Optional[TaskGroup] = None - async def poll_task(): - nonlocal set_ev - nonlocal tg + async def udp_task(): + async with await create_udp_socket( + local_host=UDP_SERVER_HOST, local_port=UDP_SERVER_PORT, reuse_port=True + ) as udp: + logger.info( + "UDP server listening on {}:{}", UDP_SERVER_HOST, UDP_SERVER_PORT + ) + async for packet, _ in udp: + await tx.send(packet) + + async def mqtt_task(): + async with MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) as client: + await client.subscribe(TOPIC) + logger.info( + "Subscribed to MQTT broker {}:{} topic {}", + MQTT_BROKER, + MQTT_BROKER_PORT, + TOPIC, + ) + async for message in client.messages: + await mqtt_tx.send(message) + + async def combined_task(): + nonlocal set_ev, tg tg = anyio.create_task_group() set_ev.set() async with tg: - async with await create_udp_socket( - local_host=UDP_SERVER_HOST, local_port=UDP_SERVER_PORT, reuse_port=True - ) as udp: - async for packet, _ in udp: - await tx.send(packet) + async with anyio.create_task_group() as inner_tg: + inner_tg.start_soon(udp_task) + inner_tg.start_soon(mqtt_task) - tr = Thread(target=anyio.run, args=(poll_task,)) + tr = Thread(target=anyio.run, args=(combined_task,)) tr.start() + while not set_ev.is_set(): - sleep(0.001) - logger.info("Poll task initialized") + sleep(DEFAULT_BUSY_POLLING_INTERVAL) + + logger.info("UDP and MQTT tasks initialized in single thread") + state: AppState = { "worker_thread": tr, "message_queue": rx, + "mqtt_message_queue": mqtt_rx, "task_group": unwrap(tg), - "history": { - "timescape": deque(maxlen=MAX_LENGTH), - "hr_data": deque(maxlen=MAX_LENGTH), - "hr_conf": deque(maxlen=MAX_LENGTH), - }, + "device_histories": {}, "refresh_inst": Instant(), } logger.info("Resource created") @@ -171,20 +223,24 @@ def resource(params: Any = None): def main(): state = resource() - history = state["history"] + device_histories = state["device_histories"] def on_export(): file_name = f"history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.parquet" logger.info(f"Exporting to {file_name}") - rec = ak.Record(history) + + # Export all device histories + export_data = { + device_id: ak.Record(dev_hist) + for device_id, dev_hist in device_histories.items() + } + rec = ak.Record(export_data) ak.to_parquet(rec, file_name) def on_clear(): - nonlocal history + nonlocal device_histories logger.info("Clearing history") - history["timescape"].clear() - history["hr_data"].clear() - history["hr_conf"].clear() + device_histories.clear() # https://docs.streamlit.io/develop/api-reference/layout st.title("MAX-BAND Visualizer") @@ -203,43 +259,132 @@ def main(): on_click=on_clear, ) + # Device selection + if device_histories: + selected_devices = st.multiselect( + "Select devices to display:", + options=list(device_histories.keys()), + default=list(device_histories.keys()), + format_func=get_device_name, + ) + else: + selected_devices = [] + placeholder = st.empty() - md_placeholder = st.empty() while True: + # Process UDP messages (treat as device_id = 0) try: message = state["message_queue"].receive_nowait() + hr_value = parse_ble_hr_measurement(message) + if hr_value is not None: + now = datetime.now() + + if UDP_DEVICE_ID not in device_histories: + device_histories[UDP_DEVICE_ID] = create_device_history() + + dev_hist = device_histories[UDP_DEVICE_ID] + dev_hist["timescape"].append(now) + dev_hist["hr_data"].append(float(hr_value)) + + logger.debug("UDP Device: HR={}", hr_value) except anyio.WouldBlock: - continue - hr_value = parse_ble_hr_measurement(message) - if hr_value is None: - logger.error("Failed to parse heart rate measurement data") - continue + pass + try: + mqtt_message = state["mqtt_message_queue"].receive_nowait() + if mqtt_message.payload: + try: + # Ensure payload is bytes + payload_bytes = mqtt_message.payload + if isinstance(payload_bytes, str): + payload_bytes = payload_bytes.encode("utf-8") + elif not isinstance(payload_bytes, bytes): + continue + + hr_packet = HrPacket() + hr_packet.parse(payload_bytes) + + now = datetime.now() + + # Extract HR data based on packet type + device_id = None + hr_value = None + + if hr_packet.hr_only_packet: + packet = hr_packet.hr_only_packet + device_id = packet.id + hr_value = packet.hr + elif hr_packet.hr_ppg_packet: + packet = hr_packet.hr_ppg_packet + device_id = packet.id + hr_value = packet.hr + + if device_id is not None and hr_value is not None: + # Ensure device history exists + if device_id not in device_histories: + device_histories[device_id] = create_device_history() + + dev_hist = device_histories[device_id] + dev_hist["timescape"].append(now) + dev_hist["hr_data"].append(float(hr_value)) + + logger.debug("Device {}: HR={}", device_id, hr_value) + + except Exception as e: + logger.error("Failed to parse MQTT protobuf message: {}", e) + except anyio.WouldBlock: + pass + + # Update visualization - HR Graphs with placeholder.container(): - history["hr_data"].append(float(hr_value)) - history["hr_conf"].append( - 100.0 - ) # Default confidence since we're not parsing it - fig_hr, fig_pd = st.tabs(["Heart Rate", "PD"]) + if device_histories: + st.subheader("Heart Rate Data") - with fig_hr: - st.plotly_chart( - go.Figure( - data=[ - go.Scatter( - y=list(history["hr_data"]), - mode="lines", - name="HR", - ), - go.Scatter( - y=list(history["hr_conf"]), - mode="lines", - name="HR Confidence", - ), - ] + # Create plots for selected devices + traces = [] + colors = [ + "red", + "green", + "blue", + "orange", + "purple", + "brown", + "pink", + "gray", + "olive", + "cyan", + ] + + for i, device_id in enumerate(selected_devices): + if device_id in device_histories: + dev_hist = device_histories[device_id] + if dev_hist["hr_data"] and dev_hist["timescape"]: + color = colors[i % len(colors)] + traces.append( + go.Scatter( + x=list(dev_hist["timescape"]), + y=list(dev_hist["hr_data"]), + mode="lines+markers", + name=get_device_name(device_id), + line=dict(color=color), + marker=dict(size=4), + ) + ) + + if traces: + fig = go.Figure(data=traces) + fig.update_layout( + title="Heart Rate Monitor", + xaxis_title="Time", + yaxis_title="Heart Rate (BPM)", + hovermode="x unified", + showlegend=True, + height=500, ) - ) + st.plotly_chart(fig, use_container_width=True) + else: + st.info("No heart rate data available for selected devices") if __name__ == "__main__":