Compare commits

..

3 Commits

325
main.py
View File

@ -11,6 +11,8 @@ from typing import (
TypedDict,
Any,
cast,
Dict,
Union,
)
from loguru import logger
@ -25,6 +27,7 @@ from threading import Thread
from time import sleep
from pydantic import BaseModel, computed_field
from datetime import datetime, timedelta
import struct
import awkward as ak
from awkward import Array as AwkwardArray, Record as AwkwardRecord
from app.model import AlgoReport
@ -32,35 +35,39 @@ from app.utils import Instant
from collections import deque
from dataclasses import dataclass
from aiomqtt import Client as MqttClient, Message as MqttMessage
from app.proto.hr_packet import HrPacket, HrOnlyPacket, HrPpgPacket, HrConfidence
import betterproto
SELECT_BAND_ID = 17
MQTT_BROKER: Final[str] = "192.168.2.189"
MQTT_BROKER: Final[str] = "weihua-iot.cn"
MQTT_BROKER_PORT: Final[int] = 1883
MAX_LENGTH = 600
TOPIC: Final[str] = "GwData"
NDArray = np.ndarray
TOPIC: Final[str] = "/hr/region/1/band/#"
UDP_DEVICE_ID: Final[int] = 0xFF
NDArray = np.ndarray
T = TypeVar("T")
class AppHistory(TypedDict):
class DeviceHistory(TypedDict):
timescape: deque[datetime]
hr_data: deque[float]
hr_conf: deque[float] # in %
# https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825
class AppState(TypedDict):
worker_thread: Thread
message_queue: MemoryObjectReceiveStream[bytes]
message_queue: MemoryObjectReceiveStream[Tuple[datetime, bytes]]
mqtt_message_queue: MemoryObjectReceiveStream[Tuple[datetime, MqttMessage]]
task_group: TaskGroup
history: AppHistory
refresh_inst: Instant
device_histories: Dict[Union[int, str], DeviceHistory] # device_id -> DeviceHistory
BUSY_POLLING_INTERVAL_S: Final[float] = 0.001
BATCH_MESSAGE_INTERVAL_S: Final[float] = 0.1
NORMAL_REFRESH_INTERVAL_S: Final[float] = 0.5
QUEUE_BUFFER_SIZE: Final[int] = 32
UDP_SERVER_HOST: Final[str] = "localhost"
UDP_SERVER_PORT: Final[int] = 50_000
MAX_LENGTH = 600
@ -75,28 +82,6 @@ def unwrap(value: Optional[T]) -> T:
return value
# /**
# * @brief Structure of the Heart Rate Measurement characteristic
# *
# * @see https://www.bluetooth.com/specifications/gss/
# * @see section 3.116 Heart Rate Measurement of the document: GATT Specification Supplement.
# */
# struct ble_hr_measurement_flag_t {
# // LSB first
# /*
# * 0: uint8_t
# * 1: uint16_t
# */
# bool heart_rate_value_format : 1;
# bool sensor_contact_detected : 1;
# bool sensor_contact_supported : 1;
# bool energy_expended_present : 1;
# bool rr_interval_present : 1;
# uint8_t reserved : 3;
# };
def parse_ble_hr_measurement(data: bytes) -> Optional[int]:
"""
Parse BLE Heart Rate Measurement characteristic data according to Bluetooth specification.
@ -106,6 +91,29 @@ def parse_ble_hr_measurement(data: bytes) -> Optional[int]:
Returns:
Heart rate value in BPM, or None if parsing fails
Note:
```cpp
/**
* @brief Structure of the Heart Rate Measurement characteristic
*
* @see https://www.bluetooth.com/specifications/gss/
* @see section 3.116 Heart Rate Measurement of the document: GATT Specification Supplement.
*/
struct ble_hr_measurement_flag_t {
// LSB first
/*
* 0: uint8_t
* 1: uint16_t
*/
bool heart_rate_value_format : 1;
bool sensor_contact_detected : 1;
bool sensor_contact_supported : 1;
bool energy_expended_present : 1;
bool rr_interval_present : 1;
uint8_t reserved : 3;
};
```
"""
if len(data) < 2:
return None
@ -131,39 +139,93 @@ def parse_ble_hr_measurement(data: bytes) -> Optional[int]:
return None
def hr_confidence_to_percentage(confidence: HrConfidence) -> float:
"""Convert HrConfidence enum to percentage value"""
if confidence == HrConfidence.ZERO:
return 0 # mid-point of [0,25)
elif confidence == HrConfidence.LOW:
return 37.5 # mid-point of [25,50)
elif confidence == HrConfidence.MEDIUM:
return 62.5 # mid-point of [50,75)
elif confidence == HrConfidence.HIGH:
return 100 # mid-point of (75,100]
else:
raise ValueError(f"Invalid HrConfidence: {confidence}")
def create_device_history() -> DeviceHistory:
"""Create a new device history structure"""
return {
"timescape": deque(maxlen=MAX_LENGTH),
"hr_data": deque(maxlen=MAX_LENGTH),
}
def get_device_name(device_id: Union[int, str]) -> str:
"""Get display name for device"""
if device_id == UDP_DEVICE_ID:
return "UDP Device"
return f"Device {device_id}"
@st.cache_resource
def resource(params: Any = None):
set_ev = anyio.Event()
tx, rx = create_memory_object_stream[bytes]()
tx, rx = create_memory_object_stream[Tuple[datetime, bytes]](
max_buffer_size=QUEUE_BUFFER_SIZE
)
mqtt_tx, mqtt_rx = create_memory_object_stream[Tuple[datetime, MqttMessage]](
max_buffer_size=QUEUE_BUFFER_SIZE
)
tg: Optional[TaskGroup] = None
async def poll_task():
nonlocal set_ev
nonlocal tg
tg = anyio.create_task_group()
set_ev.set()
async with tg:
async def udp_task():
async with await create_udp_socket(
local_host=UDP_SERVER_HOST, local_port=UDP_SERVER_PORT, reuse_port=True
) as udp:
logger.info(
"UDP server listening on {}:{}", UDP_SERVER_HOST, UDP_SERVER_PORT
)
async for packet, _ in udp:
await tx.send(packet)
timestamp = datetime.now()
await tx.send((timestamp, packet))
tr = Thread(target=anyio.run, args=(poll_task,))
async def mqtt_task():
async with MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) as client:
await client.subscribe(TOPIC)
logger.info(
"Subscribed to MQTT broker {}:{} topic {}",
MQTT_BROKER,
MQTT_BROKER_PORT,
TOPIC,
)
async for message in client.messages:
timestamp = datetime.now()
await mqtt_tx.send((timestamp, message))
async def combined_task():
nonlocal set_ev, tg
tg = anyio.create_task_group()
set_ev.set()
async with tg:
async with anyio.create_task_group() as inner_tg:
inner_tg.start_soon(udp_task)
inner_tg.start_soon(mqtt_task)
tr = Thread(target=anyio.run, args=(combined_task,))
tr.start()
while not set_ev.is_set():
sleep(0.001)
logger.info("Poll task initialized")
sleep(BUSY_POLLING_INTERVAL_S)
logger.info("UDP and MQTT tasks initialized in single thread")
state: AppState = {
"worker_thread": tr,
"message_queue": rx,
"mqtt_message_queue": mqtt_rx,
"task_group": unwrap(tg),
"history": {
"timescape": deque(maxlen=MAX_LENGTH),
"hr_data": deque(maxlen=MAX_LENGTH),
"hr_conf": deque(maxlen=MAX_LENGTH),
},
"refresh_inst": Instant(),
"device_histories": {},
}
logger.info("Resource created")
return state
@ -171,23 +233,27 @@ def resource(params: Any = None):
def main():
state = resource()
history = state["history"]
device_histories = state["device_histories"]
def on_export():
file_name = f"history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.parquet"
logger.info(f"Exporting to {file_name}")
rec = ak.Record(history)
# Export all device histories
export_data = {
device_id: ak.Record(dev_hist)
for device_id, dev_hist in device_histories.items()
}
rec = ak.Record(export_data)
ak.to_parquet(rec, file_name)
def on_clear():
nonlocal history
nonlocal device_histories
logger.info("Clearing history")
history["timescape"].clear()
history["hr_data"].clear()
history["hr_conf"].clear()
device_histories.clear()
# https://docs.streamlit.io/develop/api-reference/layout
st.title("MAX-BAND Visualizer")
st.title("HR Visualizer")
with st.container(border=True):
c1, c2 = st.columns(2)
with c1:
@ -203,44 +269,133 @@ def main():
on_click=on_clear,
)
placeholder = st.empty()
md_placeholder = st.empty()
# Device selection
if device_histories:
selected_devices = st.multiselect(
"Select devices to display:",
options=list(device_histories.keys()),
default=list(device_histories.keys()),
format_func=get_device_name,
)
else:
selected_devices = []
while True:
# Process UDP messages (treat as device_id = 0)
try:
message = state["message_queue"].receive_nowait()
while True:
timestamp, packet = state["message_queue"].receive_nowait()
hr_value = parse_ble_hr_measurement(packet)
if hr_value is not None:
now = timestamp
if UDP_DEVICE_ID not in device_histories:
device_histories[UDP_DEVICE_ID] = create_device_history()
dev_hist = device_histories[UDP_DEVICE_ID]
dev_hist["timescape"].append(now)
dev_hist["hr_data"].append(float(hr_value))
logger.debug("UDP Device: HR={}", hr_value)
except anyio.WouldBlock:
continue
hr_value = parse_ble_hr_measurement(message)
if hr_value is None:
logger.error("Failed to parse heart rate measurement data")
pass
# Process MQTT messages
try:
while True:
timestamp, mqtt_message = state["mqtt_message_queue"].receive_nowait()
if mqtt_message.payload:
try:
payload_bytes = mqtt_message.payload
if isinstance(payload_bytes, str):
payload_bytes = payload_bytes.encode("utf-8")
elif not isinstance(payload_bytes, bytes):
continue
with placeholder.container():
history["hr_data"].append(float(hr_value))
history["hr_conf"].append(
100.0
) # Default confidence since we're not parsing it
fig_hr, fig_pd = st.tabs(["Heart Rate", "PD"])
hr_packet = HrPacket()
hr_packet.parse(payload_bytes)
with fig_hr:
st.plotly_chart(
go.Figure(
data=[
go.Scatter(
y=list(history["hr_data"]),
mode="lines",
name="HR",
),
go.Scatter(
y=list(history["hr_conf"]),
mode="lines",
name="HR Confidence",
),
now = timestamp
device_id = None
hr_value = None
if hr_packet.hr_only_packet:
packet = hr_packet.hr_only_packet
device_id = packet.id
hr_value = packet.hr
elif hr_packet.hr_ppg_packet:
packet = hr_packet.hr_ppg_packet
device_id = packet.id
hr_value = packet.hr
if device_id is not None and hr_value is not None:
if device_id not in device_histories:
device_histories[device_id] = create_device_history()
dev_hist = device_histories[device_id]
dev_hist["timescape"].append(now)
dev_hist["hr_data"].append(float(hr_value))
logger.debug("Device {}: HR={}", device_id, hr_value)
except (ValueError, TypeError, IndexError, UnicodeDecodeError) as e:
logger.error("MQTT protobuf message parsing: {}", e)
except anyio.WouldBlock:
pass
# Update visualization - HR Graphs
if device_histories:
st.subheader("Heart Rate Data")
# Create plots for selected devices
traces = []
colors = [
"red",
"green",
"blue",
"orange",
"purple",
"brown",
"pink",
"gray",
"olive",
"cyan",
]
for i, device_id in enumerate(selected_devices):
if device_id in device_histories:
dev_hist = device_histories[device_id]
if dev_hist["hr_data"] and dev_hist["timescape"]:
color = colors[i % len(colors)]
traces.append(
go.Scatter(
x=list(dev_hist["timescape"]),
y=list(dev_hist["hr_data"]),
mode="lines+markers",
name=get_device_name(device_id),
line=dict(color=color),
marker=dict(size=4),
)
)
if traces:
fig = go.Figure(data=traces)
fig.update_layout(
title="Heart Rate Monitor",
xaxis_title="Time",
yaxis_title="Heart Rate (BPM)",
hovermode="x unified",
showlegend=True,
height=500,
)
st.plotly_chart(fig, use_container_width=True)
else:
st.info("No heart rate data available for selected devices")
else:
st.info("No devices connected yet. Waiting for data...")
sleep(NORMAL_REFRESH_INTERVAL_S)
st.rerun()
if __name__ == "__main__":
main()