refactor: cleanup everything and ready to use as template

This commit is contained in:
2025-02-11 11:14:46 +08:00
parent 83c7daefb9
commit ac14f7e45f
3 changed files with 181 additions and 147 deletions

150
app/model/__init__.py Normal file
View File

@ -0,0 +1,150 @@
from dataclasses import dataclass
from enum import IntEnum
import struct
from typing import ClassVar, Tuple
from pydantic import BaseModel, Field, computed_field
class AlgoOpMode(IntEnum):
"""Equivalent to max::ALGO_OP_MODE"""
CONTINUOUS_HRM_CONTINUOUS_SPO2 = 0x00 # Continuous HRM, continuous SpO2
CONTINUOUS_HRM_ONE_SHOT_SPO2 = 0x01 # Continuous HRM, one-shot SpO2
CONTINUOUS_HRM = 0x02 # Continuous HRM
SAMPLED_HRM = 0x03 # Sampled HRM
SAMPLED_HRM_ONE_SHOT_SPO2 = 0x04 # Sampled HRM, one-shot SpO2
ACTIVITY_TRACKING_ONLY = 0x05 # Activity tracking only
SPO2_CALIBRATION = 0x06 # SpO2 calibration
class ActivateClass(IntEnum):
"""Equivalent to max::ACTIVATE_CLASS"""
REST = 0
WALK = 1
RUN = 2
BIKE = 3
class SPO2State(IntEnum):
"""Equivalent to max::SPO2_STATE"""
LED_ADJUSTMENT = 0
COMPUTATION = 1
SUCCESS = 2
TIMEOUT = 3
class SCDState(IntEnum):
"""Equivalent to max::SCD_STATE"""
UNDETECTED = 0
OFF_SKIN = 1
ON_SOME_SUBJECT = 2
ON_SKIN = 3
class AlgoModelData(BaseModel):
op_mode: AlgoOpMode
hr: int # uint16, 10x calculated heart rate
hr_conf: int # uint8, confidence level in %
rr: int # uint16, 10x RR interval in ms
rr_conf: int # uint8
activity_class: ActivateClass
r: int # uint16, 1000x SpO2 R value
spo2_conf: int # uint8
spo2: int # uint16, 10x SpO2 %
spo2_percent_complete: int # uint8
spo2_low_signal_quality_flag: int # uint8
spo2_motion_flag: int # uint8
spo2_low_pi_flag: int # uint8
spo2_unreliable_r_flag: int # uint8
spo2_state: SPO2State
scd_contact_state: SCDState
reserved: int # uint32
# Format string for struct.unpack
_FORMAT: ClassVar[str] = "<BHBHBBHBBBBBBBBBL" # < for little-endian
@computed_field
def hr_f(self) -> float:
"""Heart rate in beats per minute"""
return self.hr / 10.0
@computed_field
def spo2_f(self) -> float:
"""SpO2 percentage"""
return self.spo2 / 10.0
@computed_field
def r_f(self) -> float:
"""SpO2 R value"""
return self.r / 1000.0
@computed_field
def rr_f(self) -> float:
"""RR interval in milliseconds"""
return self.rr / 10.0
@classmethod
def from_bytes(cls, data: bytes) -> "AlgoModelData":
values = struct.unpack(cls._FORMAT, data)
return cls(
op_mode=values[0],
hr=values[1],
hr_conf=values[2],
rr=values[3],
rr_conf=values[4],
activity_class=values[5],
r=values[6],
spo2_conf=values[7],
spo2=values[8],
spo2_percent_complete=values[9],
spo2_low_signal_quality_flag=values[10],
spo2_motion_flag=values[11],
spo2_low_pi_flag=values[12],
spo2_unreliable_r_flag=values[13],
spo2_state=values[14],
scd_contact_state=values[15],
reserved=values[16],
)
class AlgoReport(BaseModel):
led_1: int # uint32
led_2: int # uint32
led_3: int # uint32
accel_x: int # int16
accel_y: int # int16
accel_z: int # int16
data: AlgoModelData
@classmethod
def unmarshal(cls, buf: bytes) -> "AlgoReport":
if len(buf) < 24 + struct.calcsize(AlgoModelData._FORMAT):
raise ValueError("Buffer too small")
# Parse PPG values (3 bytes each, MSB first)
led_1 = int.from_bytes(buf[0:3], byteorder="little")
led_2 = int.from_bytes(buf[3:6], byteorder="little")
led_3 = int.from_bytes(buf[6:9], byteorder="little")
# Skip unused PPG values (bytes 9-17)
# Parse accelerometer values (2 bytes each, MSB first)
accel_x = int.from_bytes(buf[18:20], byteorder="little", signed=True)
accel_y = int.from_bytes(buf[20:22], byteorder="little", signed=True)
accel_z = int.from_bytes(buf[22:24], byteorder="little", signed=True)
# Parse algorithm data
algo_data = AlgoModelData.from_bytes(buf[24:])
return cls(
led_1=led_1,
led_2=led_2,
led_3=led_3,
accel_x=accel_x,
accel_y=accel_y,
accel_z=accel_z,
data=algo_data,
)

173
main.py
View File

@ -18,32 +18,29 @@ import numpy as np
import plotly.graph_objects as go import plotly.graph_objects as go
import streamlit as st import streamlit as st
import anyio import anyio
from anyio.abc import TaskGroup from anyio.abc import TaskGroup, UDPSocket
from anyio import create_memory_object_stream from anyio import create_memory_object_stream, create_udp_socket
from anyio.streams.memory import MemoryObjectSendStream, MemoryObjectReceiveStream from anyio.streams.memory import MemoryObjectSendStream, MemoryObjectReceiveStream
from aiomqtt import Client as MqttClient, Message as MqttMessage
from threading import Thread from threading import Thread
from time import sleep from time import sleep
from pydantic import BaseModel, computed_field from pydantic import BaseModel, computed_field
from datetime import datetime from datetime import datetime
import awkward as ak import awkward as ak
from awkward import Array as AwkwardArray, Record as AwkwardRecord from awkward import Array as AwkwardArray, Record as AwkwardRecord
import orjson from app.model import AlgoReport
from collections import deque
# https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825 # https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825
class AppState(TypedDict): class AppState(TypedDict):
worker_thread: Thread worker_thread: Thread
client: MqttClient message_queue: MemoryObjectReceiveStream[bytes]
message_queue: MemoryObjectReceiveStream[MqttMessage]
task_group: TaskGroup task_group: TaskGroup
history: dict[str, AwkwardArray] history: deque[AlgoReport]
MQTT_BROKER: Final[str] = "192.168.2.189" UDP_LISTEN_PORT: Final[int] = 50_000
MQTT_BROKER_PORT: Final[int] = 1883
MAX_LENGTH = 600 MAX_LENGTH = 600
TOPIC: Final[str] = "GwData"
NDArray = np.ndarray NDArray = np.ndarray
T = TypeVar("T") T = TypeVar("T")
@ -57,146 +54,46 @@ def unwrap(value: Optional[T]) -> T:
@st.cache_resource @st.cache_resource
def resource(params: Any = None): def resource(params: Any = None):
client: Optional[MqttClient] = None set_ev = anyio.Event()
tx, rx = create_memory_object_stream[MqttMessage]() tx, rx = create_memory_object_stream[bytes]()
tg: Optional[TaskGroup] = None tg: Optional[TaskGroup] = None
async def main(): async def poll_task():
nonlocal set_ev
nonlocal tg nonlocal tg
nonlocal client
tg = anyio.create_task_group() tg = anyio.create_task_group()
set_ev.set()
async with tg: async with tg:
client = MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) async with await create_udp_socket(
async with client: local_port=UDP_LISTEN_PORT, reuse_port=True
await client.subscribe(TOPIC) ) as udp:
logger.info( async for packet, _ in udp:
"Subscribed {}:{} to topic {}", MQTT_BROKER, MQTT_BROKER_PORT, TOPIC await tx.send(packet)
)
# https://aiomqtt.bo3hm.com/subscribing-to-a-topic.html
async for message in client.messages:
await tx.send(message)
tr = Thread(target=anyio.run, args=(main,)) tr = Thread(target=anyio.run, args=(poll_task,))
tr.start() tr.start()
sleep(0.1) while not set_ev.is_set():
sleep(0.01)
logger.info("Poll task initialized")
state: AppState = { state: AppState = {
"worker_thread": tr, "worker_thread": tr,
"client": unwrap(client),
"message_queue": rx, "message_queue": rx,
"task_group": unwrap(tg), "task_group": unwrap(tg),
"history": {}, "history": deque(maxlen=MAX_LENGTH),
} }
return state return state
class GwMessage(TypedDict):
v: int
mid: int
time: int
ip: str
mac: str
devices: list[Any]
rssi: int
class DeviceMessage(BaseModel):
mac: str
"""
Hex string, capital letters, e.g. "D6AF1CA9C491"
"""
service: str
"""
Hex string, capital letters, e.g. "180D"
"""
characteristic: str
"""
Hex string, capital letters, e.g. "2A37"
"""
value: str
"""
Hex string, capital letters, e.g. "0056"
"""
rssi: int
@property
def value_bytes(self) -> bytes:
return bytes.fromhex(self.value)
def get_device_data(message: GwMessage) -> List[DeviceMessage]:
"""
devices
[[5,"D6AF1CA9C491","180D","2A37","0056",-58],[5,"A09E1AE4E710","180D","2A37","0055",-50]]
unknown, mac addr, service, characteristic, value (hex), rssi
"""
l: list[DeviceMessage] = []
for d in message["devices"]:
x, mac, service, characteristic, value, rssi = d
l.append(
DeviceMessage(
mac=mac,
service=service,
characteristic=characteristic,
value=value,
rssi=rssi,
)
)
return l
def payload_to_hr(payload: bytes) -> int:
"""
ignore the first byte, parse the rest as a big-endian integer
Bit 0 (Heart Rate Format)
0: Heart rate value is 8 bits
1: Heart rate value is 16 bits
Bit 3 (Energy Expended)
Indicates whether energy expended data is present
Bit 4 (RR Interval)
Indicates whether RR interval data is present
"""
flags = payload[0]
if flags & 0b00000001:
return int.from_bytes(payload[1:3], "big")
else:
return payload[1]
def main(): def main():
state = resource() state = resource()
logger.info("Resource created") logger.info("Resource created")
history = state["history"] history = state["history"]
def push_new_message(message: GwMessage):
dms = get_device_data(message)
now = datetime.now()
for dm in dms:
rec = AwkwardRecord(
{
"time": now,
"value": payload_to_hr(dm.value_bytes),
"rssi": dm.rssi,
}
)
if dm.mac not in history:
history[dm.mac] = AwkwardArray([rec])
else:
history[dm.mac] = ak.concatenate([history[dm.mac], [rec]])
if len(history[dm.mac]) > MAX_LENGTH:
history[dm.mac] = AwkwardArray(history[dm.mac][-MAX_LENGTH:])
def on_export(): def on_export():
now = datetime.now() raise NotImplementedError
filename = f"export-{now.strftime('%Y-%m-%d-%H-%M-%S')}.parquet"
ak.to_parquet([history], filename)
logger.info("Export to {}", filename)
def on_clear(): def on_clear():
history.clear() raise NotImplementedError
logger.info("History cleared")
st.button( st.button(
"Export", help="Export the current data to a parquet file", on_click=on_export "Export", help="Export the current data to a parquet file", on_click=on_export
@ -208,25 +105,9 @@ def main():
message = state["message_queue"].receive_nowait() message = state["message_queue"].receive_nowait()
except anyio.WouldBlock: except anyio.WouldBlock:
continue continue
m: str # TODO: plot
if isinstance(message.payload, str): # fig = go.Figure(scatters)
m = message.payload # pannel.plotly_chart(fig)
elif isinstance(message.payload, bytes):
m = message.payload.decode("utf-8")
else:
logger.warning("Unknown message type: {}", type(message.payload))
continue
d = cast(GwMessage, orjson.loads(m))
push_new_message(d)
def to_scatter(key: str, dev_history: AwkwardArray):
x = ak.to_numpy(dev_history["time"])
y = ak.to_numpy(dev_history["value"])
return go.Scatter(x=x, y=y, mode="lines+markers", name=key)
scatters = [to_scatter(k, el) for k, el in history.items()]
fig = go.Figure(scatters)
pannel.plotly_chart(fig)
if __name__ == "__main__": if __name__ == "__main__":

5
run.sh
View File

@ -1 +1,4 @@
python -m streamlit run main.py #!/usr/bin/env bash
# python -m streamlit run main.py
python3.12 -m streamlit run main.py