From ac14f7e45fb90670199cad13ba76f181cdb51227 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 11 Feb 2025 11:14:46 +0800 Subject: [PATCH] refactor: cleanup everything and ready to use as template --- app/model/__init__.py | 150 ++++++++++++++++++++++++++++++++++++ main.py | 173 +++++++----------------------------------- run.sh | 5 +- 3 files changed, 181 insertions(+), 147 deletions(-) create mode 100644 app/model/__init__.py diff --git a/app/model/__init__.py b/app/model/__init__.py new file mode 100644 index 0000000..296e54c --- /dev/null +++ b/app/model/__init__.py @@ -0,0 +1,150 @@ +from dataclasses import dataclass +from enum import IntEnum +import struct +from typing import ClassVar, Tuple +from pydantic import BaseModel, Field, computed_field + + +class AlgoOpMode(IntEnum): + """Equivalent to max::ALGO_OP_MODE""" + + CONTINUOUS_HRM_CONTINUOUS_SPO2 = 0x00 # Continuous HRM, continuous SpO2 + CONTINUOUS_HRM_ONE_SHOT_SPO2 = 0x01 # Continuous HRM, one-shot SpO2 + CONTINUOUS_HRM = 0x02 # Continuous HRM + SAMPLED_HRM = 0x03 # Sampled HRM + SAMPLED_HRM_ONE_SHOT_SPO2 = 0x04 # Sampled HRM, one-shot SpO2 + ACTIVITY_TRACKING_ONLY = 0x05 # Activity tracking only + SPO2_CALIBRATION = 0x06 # SpO2 calibration + + +class ActivateClass(IntEnum): + """Equivalent to max::ACTIVATE_CLASS""" + + REST = 0 + WALK = 1 + RUN = 2 + BIKE = 3 + + +class SPO2State(IntEnum): + """Equivalent to max::SPO2_STATE""" + + LED_ADJUSTMENT = 0 + COMPUTATION = 1 + SUCCESS = 2 + TIMEOUT = 3 + + +class SCDState(IntEnum): + """Equivalent to max::SCD_STATE""" + + UNDETECTED = 0 + OFF_SKIN = 1 + ON_SOME_SUBJECT = 2 + ON_SKIN = 3 + + +class AlgoModelData(BaseModel): + op_mode: AlgoOpMode + hr: int # uint16, 10x calculated heart rate + hr_conf: int # uint8, confidence level in % + rr: int # uint16, 10x RR interval in ms + rr_conf: int # uint8 + activity_class: ActivateClass + r: int # uint16, 1000x SpO2 R value + spo2_conf: int # uint8 + spo2: int # uint16, 10x SpO2 % + spo2_percent_complete: int # uint8 + spo2_low_signal_quality_flag: int # uint8 + spo2_motion_flag: int # uint8 + spo2_low_pi_flag: int # uint8 + spo2_unreliable_r_flag: int # uint8 + spo2_state: SPO2State + scd_contact_state: SCDState + reserved: int # uint32 + + # Format string for struct.unpack + _FORMAT: ClassVar[str] = " float: + """Heart rate in beats per minute""" + return self.hr / 10.0 + + @computed_field + def spo2_f(self) -> float: + """SpO2 percentage""" + return self.spo2 / 10.0 + + @computed_field + def r_f(self) -> float: + """SpO2 R value""" + return self.r / 1000.0 + + @computed_field + def rr_f(self) -> float: + """RR interval in milliseconds""" + return self.rr / 10.0 + + @classmethod + def from_bytes(cls, data: bytes) -> "AlgoModelData": + values = struct.unpack(cls._FORMAT, data) + return cls( + op_mode=values[0], + hr=values[1], + hr_conf=values[2], + rr=values[3], + rr_conf=values[4], + activity_class=values[5], + r=values[6], + spo2_conf=values[7], + spo2=values[8], + spo2_percent_complete=values[9], + spo2_low_signal_quality_flag=values[10], + spo2_motion_flag=values[11], + spo2_low_pi_flag=values[12], + spo2_unreliable_r_flag=values[13], + spo2_state=values[14], + scd_contact_state=values[15], + reserved=values[16], + ) + + +class AlgoReport(BaseModel): + led_1: int # uint32 + led_2: int # uint32 + led_3: int # uint32 + accel_x: int # int16 + accel_y: int # int16 + accel_z: int # int16 + data: AlgoModelData + + @classmethod + def unmarshal(cls, buf: bytes) -> "AlgoReport": + if len(buf) < 24 + struct.calcsize(AlgoModelData._FORMAT): + raise ValueError("Buffer too small") + + # Parse PPG values (3 bytes each, MSB first) + led_1 = int.from_bytes(buf[0:3], byteorder="little") + led_2 = int.from_bytes(buf[3:6], byteorder="little") + led_3 = int.from_bytes(buf[6:9], byteorder="little") + + # Skip unused PPG values (bytes 9-17) + + # Parse accelerometer values (2 bytes each, MSB first) + accel_x = int.from_bytes(buf[18:20], byteorder="little", signed=True) + accel_y = int.from_bytes(buf[20:22], byteorder="little", signed=True) + accel_z = int.from_bytes(buf[22:24], byteorder="little", signed=True) + + # Parse algorithm data + algo_data = AlgoModelData.from_bytes(buf[24:]) + + return cls( + led_1=led_1, + led_2=led_2, + led_3=led_3, + accel_x=accel_x, + accel_y=accel_y, + accel_z=accel_z, + data=algo_data, + ) diff --git a/main.py b/main.py index 934a634..058c60e 100644 --- a/main.py +++ b/main.py @@ -18,32 +18,29 @@ import numpy as np import plotly.graph_objects as go import streamlit as st import anyio -from anyio.abc import TaskGroup -from anyio import create_memory_object_stream +from anyio.abc import TaskGroup, UDPSocket +from anyio import create_memory_object_stream, create_udp_socket from anyio.streams.memory import MemoryObjectSendStream, MemoryObjectReceiveStream -from aiomqtt import Client as MqttClient, Message as MqttMessage from threading import Thread from time import sleep from pydantic import BaseModel, computed_field from datetime import datetime import awkward as ak from awkward import Array as AwkwardArray, Record as AwkwardRecord -import orjson +from app.model import AlgoReport +from collections import deque # https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825 class AppState(TypedDict): worker_thread: Thread - client: MqttClient - message_queue: MemoryObjectReceiveStream[MqttMessage] + message_queue: MemoryObjectReceiveStream[bytes] task_group: TaskGroup - history: dict[str, AwkwardArray] + history: deque[AlgoReport] -MQTT_BROKER: Final[str] = "192.168.2.189" -MQTT_BROKER_PORT: Final[int] = 1883 +UDP_LISTEN_PORT: Final[int] = 50_000 MAX_LENGTH = 600 -TOPIC: Final[str] = "GwData" NDArray = np.ndarray T = TypeVar("T") @@ -57,146 +54,46 @@ def unwrap(value: Optional[T]) -> T: @st.cache_resource def resource(params: Any = None): - client: Optional[MqttClient] = None - tx, rx = create_memory_object_stream[MqttMessage]() + set_ev = anyio.Event() + tx, rx = create_memory_object_stream[bytes]() tg: Optional[TaskGroup] = None - async def main(): + async def poll_task(): + nonlocal set_ev nonlocal tg - nonlocal client tg = anyio.create_task_group() + set_ev.set() async with tg: - client = MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) - async with client: - await client.subscribe(TOPIC) - logger.info( - "Subscribed {}:{} to topic {}", MQTT_BROKER, MQTT_BROKER_PORT, TOPIC - ) - # https://aiomqtt.bo3hm.com/subscribing-to-a-topic.html - async for message in client.messages: - await tx.send(message) + async with await create_udp_socket( + local_port=UDP_LISTEN_PORT, reuse_port=True + ) as udp: + async for packet, _ in udp: + await tx.send(packet) - tr = Thread(target=anyio.run, args=(main,)) + tr = Thread(target=anyio.run, args=(poll_task,)) tr.start() - sleep(0.1) + while not set_ev.is_set(): + sleep(0.01) + logger.info("Poll task initialized") state: AppState = { "worker_thread": tr, - "client": unwrap(client), "message_queue": rx, "task_group": unwrap(tg), - "history": {}, + "history": deque(maxlen=MAX_LENGTH), } return state -class GwMessage(TypedDict): - v: int - mid: int - time: int - ip: str - mac: str - devices: list[Any] - rssi: int - - -class DeviceMessage(BaseModel): - mac: str - """ - Hex string, capital letters, e.g. "D6AF1CA9C491" - """ - service: str - """ - Hex string, capital letters, e.g. "180D" - """ - characteristic: str - """ - Hex string, capital letters, e.g. "2A37" - """ - value: str - """ - Hex string, capital letters, e.g. "0056" - """ - rssi: int - - @property - def value_bytes(self) -> bytes: - return bytes.fromhex(self.value) - - -def get_device_data(message: GwMessage) -> List[DeviceMessage]: - """ - devices - - [[5,"D6AF1CA9C491","180D","2A37","0056",-58],[5,"A09E1AE4E710","180D","2A37","0055",-50]] - - unknown, mac addr, service, characteristic, value (hex), rssi - """ - l: list[DeviceMessage] = [] - for d in message["devices"]: - x, mac, service, characteristic, value, rssi = d - l.append( - DeviceMessage( - mac=mac, - service=service, - characteristic=characteristic, - value=value, - rssi=rssi, - ) - ) - return l - - -def payload_to_hr(payload: bytes) -> int: - """ - ignore the first byte, parse the rest as a big-endian integer - - Bit 0 (Heart Rate Format) - 0: Heart rate value is 8 bits - 1: Heart rate value is 16 bits - Bit 3 (Energy Expended) - Indicates whether energy expended data is present - Bit 4 (RR Interval) - Indicates whether RR interval data is present - """ - flags = payload[0] - if flags & 0b00000001: - return int.from_bytes(payload[1:3], "big") - else: - return payload[1] - - def main(): state = resource() logger.info("Resource created") history = state["history"] - def push_new_message(message: GwMessage): - dms = get_device_data(message) - now = datetime.now() - for dm in dms: - rec = AwkwardRecord( - { - "time": now, - "value": payload_to_hr(dm.value_bytes), - "rssi": dm.rssi, - } - ) - if dm.mac not in history: - history[dm.mac] = AwkwardArray([rec]) - else: - history[dm.mac] = ak.concatenate([history[dm.mac], [rec]]) - if len(history[dm.mac]) > MAX_LENGTH: - history[dm.mac] = AwkwardArray(history[dm.mac][-MAX_LENGTH:]) - def on_export(): - now = datetime.now() - filename = f"export-{now.strftime('%Y-%m-%d-%H-%M-%S')}.parquet" - ak.to_parquet([history], filename) - logger.info("Export to {}", filename) + raise NotImplementedError def on_clear(): - history.clear() - logger.info("History cleared") + raise NotImplementedError st.button( "Export", help="Export the current data to a parquet file", on_click=on_export @@ -208,25 +105,9 @@ def main(): message = state["message_queue"].receive_nowait() except anyio.WouldBlock: continue - m: str - if isinstance(message.payload, str): - m = message.payload - elif isinstance(message.payload, bytes): - m = message.payload.decode("utf-8") - else: - logger.warning("Unknown message type: {}", type(message.payload)) - continue - d = cast(GwMessage, orjson.loads(m)) - push_new_message(d) - - def to_scatter(key: str, dev_history: AwkwardArray): - x = ak.to_numpy(dev_history["time"]) - y = ak.to_numpy(dev_history["value"]) - return go.Scatter(x=x, y=y, mode="lines+markers", name=key) - - scatters = [to_scatter(k, el) for k, el in history.items()] - fig = go.Figure(scatters) - pannel.plotly_chart(fig) + # TODO: plot + # fig = go.Figure(scatters) + # pannel.plotly_chart(fig) if __name__ == "__main__": diff --git a/run.sh b/run.sh index 086a355..19162a3 100755 --- a/run.sh +++ b/run.sh @@ -1 +1,4 @@ -python -m streamlit run main.py \ No newline at end of file +#!/usr/bin/env bash + +# python -m streamlit run main.py +python3.12 -m streamlit run main.py \ No newline at end of file