# Files: hr_visualize/main.py (402 lines, 12 KiB, Python)

from typing import (
Annotated,
AsyncGenerator,
Final,
Generator,
List,
Literal,
Optional,
Tuple,
TypeVar,
TypedDict,
Any,
cast,
Dict,
Union,
)
from loguru import logger
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import anyio
from anyio.abc import TaskGroup, UDPSocket
from anyio import create_memory_object_stream, create_udp_socket
from anyio.streams.memory import MemoryObjectSendStream, MemoryObjectReceiveStream
from threading import Thread
from time import sleep
from pydantic import BaseModel, computed_field
from datetime import datetime, timedelta
import struct
import awkward as ak
from awkward import Array as AwkwardArray, Record as AwkwardRecord
from app.model import AlgoReport
from app.utils import Instant
from collections import deque
from dataclasses import dataclass
from aiomqtt import Client as MqttClient, Message as MqttMessage
from app.proto.hr_packet import HrPacket, HrOnlyPacket, HrPpgPacket, HrConfidence
import betterproto
# Host name and port of the MQTT broker that mqtt_task() connects to.
MQTT_BROKER: Final[str] = "weihua-iot.cn"
MQTT_BROKER_PORT: Final[int] = 1883
# Maximum number of samples kept per device; used as the deque ``maxlen``
# in create_device_history().
MAX_LENGTH = 600
# MQTT topic filter passed to ``client.subscribe``.
TOPIC: Final[str] = "/hr/region/1/band/#"
# Key under which samples received over UDP are stored in device_histories.
UDP_DEVICE_ID: Final[int] = 0xFF
# Short alias for numpy's array type.
NDArray = np.ndarray
# Generic type variable used in the signature of unwrap().
T = TypeVar("T")
class DeviceHistory(TypedDict):
    """Per-device sample buffers.

    main() appends to both deques together, so equal indexes describe the
    same sample. create_device_history() bounds each deque with MAX_LENGTH.
    """

    # Timestamp recorded when each sample was received.
    timescape: deque[datetime]
    # Heart-rate readings, stored as floats (plotted as BPM in main()).
    hr_data: deque[float]
# Background reading on combining Streamlit with asyncio:
# https://handmadesoftware.medium.com/streamlit-asyncio-and-mongodb-f85f77aea825
class AppState(TypedDict):
    """Objects built once by resource() and reused by main()."""

    # Thread that runs the anyio event loop hosting the UDP and MQTT tasks.
    worker_thread: Thread
    # Receive end of the stream of (receive time, raw UDP datagram) pairs.
    message_queue: MemoryObjectReceiveStream[Tuple[datetime, bytes]]
    # Receive end of the stream of (receive time, MQTT message) pairs.
    mqtt_message_queue: MemoryObjectReceiveStream[Tuple[datetime, MqttMessage]]
    # Task group created inside the worker thread's event loop.
    task_group: TaskGroup
    device_histories: Dict[Union[int, str], DeviceHistory]  # device_id -> DeviceHistory
# Sleep duration, in seconds, for short busy-wait loops.
BUSY_POLLING_INTERVAL_S: Final[float] = 0.001
# NOTE(review): not referenced anywhere in the visible part of this file.
BATCH_MESSAGE_INTERVAL_S: Final[float] = 0.1
# Pause, in seconds, before main() asks Streamlit to rerun the script.
NORMAL_REFRESH_INTERVAL_S: Final[float] = 0.5
# ``max_buffer_size`` of the memory object streams created in resource().
QUEUE_BUFFER_SIZE: Final[int] = 32
# Local address the UDP socket is bound to.
UDP_SERVER_HOST: Final[str] = "localhost"
UDP_SERVER_PORT: Final[int] = 50_000
# NOTE(review): the three names below repeat identical definitions made
# earlier in this module; the repetition has no effect at runtime.
MAX_LENGTH = 600
NDArray = np.ndarray
T = TypeVar("T")
def unwrap(value: Optional[T]) -> T:
    """Return ``value`` unchanged, insisting that it is not ``None``.

    Args:
        value: The possibly-missing value.

    Returns:
        The same object that was passed in.

    Raises:
        ValueError: If ``value`` is ``None``.
    """
    if value is not None:
        return value
    raise ValueError("Value is None")
def parse_ble_hr_measurement(data: bytes) -> Optional[int]:
    """
    Extract the heart rate from a BLE Heart Rate Measurement payload.

    The first byte holds the flags. Bit 0 of the flags selects the width of
    the heart-rate value that follows: 0 means one byte (uint8), 1 means two
    bytes (uint16, little endian). The remaining flag bits are not examined.

    Args:
        data: Raw bytes from the heart rate measurement characteristic
    Returns:
        Heart rate value in BPM, or None if the payload is too short or
        cannot be decoded
    Note:
        Flag layout (LSB first), for reference:

        ```cpp
        /**
         * @brief Structure of the Heart Rate Measurement characteristic
         *
         * @see https://www.bluetooth.com/specifications/gss/
         * @see section 3.116 Heart Rate Measurement of the document: GATT Specification Supplement.
         */
        struct ble_hr_measurement_flag_t {
          // LSB first
          /*
           * 0: uint8_t
           * 1: uint16_t
           */
          bool heart_rate_value_format : 1;
          bool sensor_contact_detected : 1;
          bool sensor_contact_supported : 1;
          bool energy_expended_present : 1;
          bool rr_interval_present : 1;
          uint8_t reserved : 3;
        };
        ```
    """
    # A valid payload has the flags byte plus at least one value byte.
    if len(data) < 2:
        return None

    is_uint16 = bool(data[0] & 0x01)
    # The wide format needs both value bytes to be present.
    if is_uint16 and len(data) < 3:
        return None

    try:
        if is_uint16:
            return int.from_bytes(data[1:3], byteorder="little")
        return data[1]
    except (IndexError, ValueError):
        return None
def hr_confidence_to_percentage(confidence: HrConfidence) -> float:
    """Map an HrConfidence level to a representative percentage.

    Args:
        confidence: The confidence level to convert.

    Returns:
        0 for ZERO, 37.5 for LOW, 62.5 for MEDIUM and 100 for HIGH.

    Raises:
        ValueError: If ``confidence`` matches none of the four levels.
    """
    # LOW and MEDIUM use the mid-points of [25,50) and [50,75).
    # NOTE(review): ZERO and HIGH use the outer bounds 0 and 100, not the
    # mid-points 12.5 and 87.5 of their quarter ranges - confirm intent.
    percentages = (
        (HrConfidence.ZERO, 0),
        (HrConfidence.LOW, 37.5),
        (HrConfidence.MEDIUM, 62.5),
        (HrConfidence.HIGH, 100),
    )
    for level, percentage in percentages:
        if confidence == level:
            return percentage
    raise ValueError(f"Invalid HrConfidence: {confidence}")
def create_device_history() -> DeviceHistory:
    """Build an empty history whose two buffers each hold at most MAX_LENGTH items."""
    return DeviceHistory(
        timescape=deque(maxlen=MAX_LENGTH),
        hr_data=deque(maxlen=MAX_LENGTH),
    )
def get_device_name(device_id: Union[int, str]) -> str:
    """Return the label shown for a device in the UI.

    The reserved UDP_DEVICE_ID is labelled "UDP Device"; every other id is
    rendered as "Device <id>".
    """
    return "UDP Device" if device_id == UDP_DEVICE_ID else f"Device {device_id}"
@st.cache_resource
def resource(params: Any = None):
    """Start the background receiver thread and return the shared app state.

    One worker thread runs an anyio event loop with two tasks:

    * ``udp_task`` binds a UDP socket on UDP_SERVER_HOST:UDP_SERVER_PORT and
      forwards every datagram, stamped with its receive time, into a memory
      object stream.
    * ``mqtt_task`` connects to MQTT_BROKER, subscribes to TOPIC and forwards
      every message, stamped with its receive time, into a second stream.

    The function blocks until the worker has created its task group, so the
    returned state always carries a task group.

    Args:
        params: Not used by the body.

    Returns:
        An AppState holding the worker thread, the receive ends of both
        streams, the worker's task group and an empty ``device_histories``
        mapping.
    """
    # Local import: only ``Thread`` is imported from ``threading`` at module level.
    import threading

    # Signals from the worker thread to the calling thread that ``tg`` is set.
    # A threading.Event is used because this flag crosses threads and is
    # created outside of any event loop.
    ready = threading.Event()
    tx, rx = create_memory_object_stream[Tuple[datetime, bytes]](
        max_buffer_size=QUEUE_BUFFER_SIZE
    )
    mqtt_tx, mqtt_rx = create_memory_object_stream[Tuple[datetime, MqttMessage]](
        max_buffer_size=QUEUE_BUFFER_SIZE
    )
    tg: Optional[TaskGroup] = None

    async def udp_task():
        async with await create_udp_socket(
            local_host=UDP_SERVER_HOST, local_port=UDP_SERVER_PORT, reuse_port=True
        ) as udp:
            logger.info(
                "UDP server listening on {}:{}", UDP_SERVER_HOST, UDP_SERVER_PORT
            )
            # The sender address is ignored; only the payload is forwarded.
            async for packet, _ in udp:
                timestamp = datetime.now()
                await tx.send((timestamp, packet))

    async def mqtt_task():
        async with MqttClient(MQTT_BROKER, port=MQTT_BROKER_PORT) as client:
            await client.subscribe(TOPIC)
            logger.info(
                "Subscribed to MQTT broker {}:{} topic {}",
                MQTT_BROKER,
                MQTT_BROKER_PORT,
                TOPIC,
            )
            async for message in client.messages:
                timestamp = datetime.now()
                await mqtt_tx.send((timestamp, message))

    async def combined_task():
        nonlocal tg
        tg = anyio.create_task_group()
        # ``tg`` is assigned before the flag is raised, so the caller never
        # observes ``tg is None`` after waiting.
        ready.set()
        async with tg:
            async with anyio.create_task_group() as inner_tg:
                inner_tg.start_soon(udp_task)
                inner_tg.start_soon(mqtt_task)

    # Daemon thread: both tasks loop forever and nothing here stops them, so
    # a non-daemon worker would keep the interpreter alive at shutdown.
    tr = Thread(target=anyio.run, args=(combined_task,), daemon=True)
    tr.start()
    ready.wait()
    logger.info("UDP and MQTT tasks initialized in single thread")
    state: AppState = {
        "worker_thread": tr,
        "message_queue": rx,
        "mqtt_message_queue": mqtt_rx,
        "task_group": unwrap(tg),
        "device_histories": {},
    }
    logger.info("Resource created")
    return state
def main():
    """Render one pass of the Streamlit page and then trigger a rerun.

    Each pass:

    1. draws the title, the Export / Clear buttons and the device selector,
    2. drains the UDP and MQTT queues from resource() into the per-device
       histories,
    3. plots the selected devices' heart-rate traces,
    4. sleeps NORMAL_REFRESH_INTERVAL_S seconds and calls ``st.rerun()``.

    Devices first seen while draining appear in the selector on the next
    pass, because the selector is drawn before the queues are drained.
    """
    state = resource()
    device_histories = state["device_histories"]

    def record_sample(device_id, timestamp, hr_value):
        """Append one sample, creating the device's history on first use."""
        dev_hist = device_histories.get(device_id)
        if dev_hist is None:
            dev_hist = device_histories[device_id] = create_device_history()
        dev_hist["timescape"].append(timestamp)
        dev_hist["hr_data"].append(float(hr_value))

    def on_export():
        file_name = f"history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.parquet"
        logger.info(f"Exporting to {file_name}")
        # Export all device histories. Device ids are turned into strings
        # because they become record field names, and awkward is expected
        # to accept only string keys when building a record from a dict.
        # NOTE(review): an int id and an equal-looking str id (1 vs "1")
        # would collapse into one field - confirm ids never mix like that.
        export_data = {
            str(device_id): ak.Record(dev_hist)
            for device_id, dev_hist in device_histories.items()
        }
        rec = ak.Record(export_data)
        ak.to_parquet(rec, file_name)

    def on_clear():
        logger.info("Clearing history")
        # Cleared in place so the dict cached in the app state stays shared.
        device_histories.clear()

    # https://docs.streamlit.io/develop/api-reference/layout
    st.title("HR Visualizer")
    with st.container(border=True):
        c1, c2 = st.columns(2)
        with c1:
            st.button(
                "Export",
                help="Export the current data to a parquet file",
                on_click=on_export,
            )
        with c2:
            st.button(
                "Clear",
                help="Clear the current data",
                on_click=on_clear,
            )

    # Device selection
    if device_histories:
        selected_devices = st.multiselect(
            "Select devices to display:",
            options=list(device_histories.keys()),
            default=list(device_histories.keys()),
            format_func=get_device_name,
        )
    else:
        selected_devices = []

    # Process UDP messages (stored under device_id = UDP_DEVICE_ID).
    # receive_nowait() raises WouldBlock once the queue is empty, which is
    # what ends the drain loop.
    try:
        while True:
            timestamp, packet = state["message_queue"].receive_nowait()
            hr_value = parse_ble_hr_measurement(packet)
            if hr_value is not None:
                record_sample(UDP_DEVICE_ID, timestamp, hr_value)
                logger.debug("UDP Device: HR={}", hr_value)
    except anyio.WouldBlock:
        pass

    # Process MQTT messages
    try:
        while True:
            timestamp, mqtt_message = state["mqtt_message_queue"].receive_nowait()
            if not mqtt_message.payload:
                continue
            try:
                payload_bytes = mqtt_message.payload
                if isinstance(payload_bytes, str):
                    payload_bytes = payload_bytes.encode("utf-8")
                elif not isinstance(payload_bytes, bytes):
                    # Payloads that are neither str nor bytes are skipped.
                    continue
                hr_packet = HrPacket()
                hr_packet.parse(payload_bytes)
                device_id = None
                hr_value = None
                # Both packet variants provide ``id`` and ``hr``.
                if hr_packet.hr_only_packet:
                    packet = hr_packet.hr_only_packet
                    device_id = packet.id
                    hr_value = packet.hr
                elif hr_packet.hr_ppg_packet:
                    packet = hr_packet.hr_ppg_packet
                    device_id = packet.id
                    hr_value = packet.hr
                if device_id is not None and hr_value is not None:
                    record_sample(device_id, timestamp, hr_value)
                    logger.debug("Device {}: HR={}", device_id, hr_value)
            except (ValueError, TypeError, IndexError, UnicodeDecodeError) as e:
                # A malformed message is logged and dropped; draining continues.
                logger.error("MQTT protobuf message parsing: {}", e)
    except anyio.WouldBlock:
        pass

    # Update visualization - HR Graphs
    if device_histories:
        st.subheader("Heart Rate Data")
        # Create plots for selected devices
        traces = []
        colors = [
            "red",
            "green",
            "blue",
            "orange",
            "purple",
            "brown",
            "pink",
            "gray",
            "olive",
            "cyan",
        ]
        for i, device_id in enumerate(selected_devices):
            if device_id in device_histories:
                dev_hist = device_histories[device_id]
                if dev_hist["hr_data"] and dev_hist["timescape"]:
                    # Colours repeat once there are more devices than colours.
                    color = colors[i % len(colors)]
                    traces.append(
                        go.Scatter(
                            x=list(dev_hist["timescape"]),
                            y=list(dev_hist["hr_data"]),
                            mode="lines+markers",
                            name=get_device_name(device_id),
                            line=dict(color=color),
                            marker=dict(size=4),
                        )
                    )
        if traces:
            fig = go.Figure(data=traces)
            fig.update_layout(
                title="Heart Rate Monitor",
                xaxis_title="Time",
                yaxis_title="Heart Rate (BPM)",
                hovermode="x unified",
                showlegend=True,
                height=500,
            )
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.info("No heart rate data available for selected devices")
    else:
        st.info("No devices connected yet. Waiting for data...")

    # Poll: wait briefly, then rerun the script to pick up new samples.
    sleep(NORMAL_REFRESH_INTERVAL_S)
    st.rerun()
# Run the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()