diff --git a/bindings/rpt_module.cpp b/bindings/rpt_module.cpp index 39aedc2..f1aba4d 100644 --- a/bindings/rpt_module.cpp +++ b/bindings/rpt_module.cpp @@ -20,6 +20,7 @@ namespace using PoseArray2D = nb::ndarray, nb::c_contig>; using CountArray = nb::ndarray, nb::c_contig>; +using TrackIdArray = nb::ndarray, nb::c_contig>; using PoseArray3DConst = nb::ndarray, nb::c_contig>; using PoseArray3D = nb::ndarray, nb::c_contig>; @@ -58,6 +59,24 @@ PoseBatch3DView pose_batch3d_view_from_numpy(const PoseArray3DConst &poses_3d) }; } +TrackedPoseBatch3DView tracked_pose_batch_view_from_numpy( + const PoseArray3DConst &poses_3d, + const TrackIdArray &track_ids) +{ + if (poses_3d.shape(0) != track_ids.shape(0)) + { + throw std::invalid_argument( + "previous_poses_3d and previous_track_ids must have the same number of tracks."); + } + + return TrackedPoseBatch3DView { + track_ids.data(), + poses_3d.data(), + static_cast(poses_3d.shape(0)), + static_cast(poses_3d.shape(1)), + }; +} + PoseArray3D pose_batch_to_numpy(PoseBatch3D batch) { auto *storage = new std::vector(std::move(batch.data)); @@ -216,6 +235,7 @@ NB_MODULE(_core, m) nb::class_(m, "PreviousPoseMatch") .def(nb::init<>()) .def_rw("previous_pose_index", &PreviousPoseMatch::previous_pose_index) + .def_rw("previous_track_id", &PreviousPoseMatch::previous_track_id) .def_rw("score_view1", &PreviousPoseMatch::score_view1) .def_rw("score_view2", &PreviousPoseMatch::score_view2) .def_rw("matched_view1", &PreviousPoseMatch::matched_view1) @@ -274,6 +294,34 @@ NB_MODULE(_core, m) return merged_poses_to_numpy_copy(merge.merged_poses); }, nb::rv_policy::move); + nb::enum_(m, "AssociationStatus") + .value("MATCHED", AssociationStatus::Matched) + .value("NEW", AssociationStatus::New) + .value("AMBIGUOUS", AssociationStatus::Ambiguous); + + nb::class_(m, "AssociationReport") + .def(nb::init<>()) + .def_rw("pose_previous_indices", &AssociationReport::pose_previous_indices) + .def_rw("pose_previous_track_ids", &AssociationReport::pose_previous_track_ids) + .def_rw("pose_status", &AssociationReport::pose_status) + .def_rw("pose_candidate_previous_indices", &AssociationReport::pose_candidate_previous_indices) + .def_rw("pose_candidate_previous_track_ids", &AssociationReport::pose_candidate_previous_track_ids) + .def_rw("unmatched_previous_indices", &AssociationReport::unmatched_previous_indices) + .def_rw("unmatched_previous_track_ids", &AssociationReport::unmatched_previous_track_ids) + .def_rw("new_pose_indices", &AssociationReport::new_pose_indices) + .def_rw("ambiguous_pose_indices", &AssociationReport::ambiguous_pose_indices); + + nb::class_(m, "FinalPoseAssociationDebug") + .def(nb::init<>()) + .def_rw("final_pose_index", &FinalPoseAssociationDebug::final_pose_index) + .def_rw("source_core_proposal_indices", &FinalPoseAssociationDebug::source_core_proposal_indices) + .def_rw("source_pair_indices", &FinalPoseAssociationDebug::source_pair_indices) + .def_rw("candidate_previous_indices", &FinalPoseAssociationDebug::candidate_previous_indices) + .def_rw("candidate_previous_track_ids", &FinalPoseAssociationDebug::candidate_previous_track_ids) + .def_rw("resolved_previous_index", &FinalPoseAssociationDebug::resolved_previous_index) + .def_rw("resolved_previous_track_id", &FinalPoseAssociationDebug::resolved_previous_track_id) + .def_rw("status", &FinalPoseAssociationDebug::status); + nb::class_(m, "TriangulationTrace") .def(nb::init<>()) .def_rw("pairs", &TriangulationTrace::pairs) @@ -282,11 +330,21 @@ NB_MODULE(_core, m) .def_rw("grouping", &TriangulationTrace::grouping) .def_rw("full_proposals", &TriangulationTrace::full_proposals) .def_rw("merge", &TriangulationTrace::merge) + .def_rw("association", &TriangulationTrace::association) + .def_rw("final_pose_associations", &TriangulationTrace::final_pose_associations) .def_prop_ro("final_poses", [](const TriangulationTrace &trace) { return pose_batch_to_numpy_copy(trace.final_poses); }, nb::rv_policy::move); + nb::class_(m, "TriangulationResult") + .def(nb::init<>()) + .def_rw("association", &TriangulationResult::association) + .def_prop_ro("poses_3d", [](const TriangulationResult &result) + { + return pose_batch_to_numpy_copy(result.poses); + }, nb::rv_policy::move); + m.def( "make_camera", &make_camera, @@ -313,17 +371,19 @@ NB_MODULE(_core, m) [](const PoseArray2D &poses_2d, const CountArray &person_counts, const TriangulationConfig &config, - const PoseArray3DConst &previous_poses_3d) + const PoseArray3DConst &previous_poses_3d, + const TrackIdArray &previous_track_ids) { return filter_pairs_with_previous_poses( pose_batch_view_from_numpy(poses_2d, person_counts), config, - pose_batch3d_view_from_numpy(previous_poses_3d)); + tracked_pose_batch_view_from_numpy(previous_poses_3d, previous_track_ids)); }, "poses_2d"_a, "person_counts"_a, "config"_a, - "previous_poses_3d"_a); + "previous_poses_3d"_a, + "previous_track_ids"_a); m.def( "triangulate_debug", @@ -342,9 +402,11 @@ NB_MODULE(_core, m) [](const PoseArray2D &poses_2d, const CountArray &person_counts, const TriangulationConfig &config, - const PoseArray3DConst &previous_poses_3d) + const PoseArray3DConst &previous_poses_3d, + const TrackIdArray &previous_track_ids) { - const PoseBatch3DView previous_view = pose_batch3d_view_from_numpy(previous_poses_3d); + const TrackedPoseBatch3DView previous_view = + tracked_pose_batch_view_from_numpy(previous_poses_3d, previous_track_ids); return triangulate_debug( pose_batch_view_from_numpy(poses_2d, person_counts), config, @@ -353,7 +415,8 @@ NB_MODULE(_core, m) "poses_2d"_a, "person_counts"_a, "config"_a, - "previous_poses_3d"_a); + "previous_poses_3d"_a, + "previous_track_ids"_a); m.def( "triangulate_poses", @@ -370,21 +433,22 @@ NB_MODULE(_core, m) "config"_a); m.def( - "triangulate_poses", + "triangulate_with_report", [](const PoseArray2D &poses_2d, const CountArray &person_counts, const TriangulationConfig &config, - const PoseArray3DConst &previous_poses_3d) + const PoseArray3DConst &previous_poses_3d, + const TrackIdArray &previous_track_ids) { - const PoseBatch3DView previous_view = pose_batch3d_view_from_numpy(previous_poses_3d); - const PoseBatch3D poses_3d = triangulate_poses( + const TriangulationResult result = triangulate_with_report( pose_batch_view_from_numpy(poses_2d, person_counts), config, - &previous_view); - return pose_batch_to_numpy(poses_3d); + tracked_pose_batch_view_from_numpy(previous_poses_3d, previous_track_ids)); + return result; }, "poses_2d"_a, "person_counts"_a, "config"_a, - "previous_poses_3d"_a); + "previous_poses_3d"_a, + "previous_track_ids"_a); } diff --git a/rpt/interface.cpp b/rpt/interface.cpp index 40d8ff6..bc12869 100644 --- a/rpt/interface.cpp +++ b/rpt/interface.cpp @@ -43,6 +43,16 @@ const float &PoseBatch3DView::at(size_t person, size_t joint, size_t coord) cons return data[pose3d_offset(person, joint, coord, num_joints)]; } +int64_t TrackedPoseBatch3DView::track_id(size_t person) const +{ + return track_ids[person]; +} + +const float &TrackedPoseBatch3DView::at(size_t person, size_t joint, size_t coord) const +{ + return data[pose3d_offset(person, joint, coord, num_joints)]; +} + const float &PoseBatch2D::at(size_t view, size_t person, size_t joint, size_t coord) const { return data[pose2d_offset(view, person, joint, coord, max_persons, num_joints)]; diff --git a/rpt/interface.hpp b/rpt/interface.hpp index a5b3f8f..422ee39 100644 --- a/rpt/interface.hpp +++ b/rpt/interface.hpp @@ -34,6 +34,17 @@ struct PoseBatch3DView const float &at(size_t person, size_t joint, size_t coord) const; }; +struct TrackedPoseBatch3DView +{ + const int64_t *track_ids = nullptr; + const float *data = nullptr; + size_t num_persons = 0; + size_t num_joints = 0; + + int64_t track_id(size_t person) const; + const float &at(size_t person, size_t joint, size_t coord) const; +}; + struct PoseBatch2D { std::vector data; @@ -78,6 +89,7 @@ struct PairCandidate struct PreviousPoseMatch { int previous_pose_index = -1; + int64_t previous_track_id = -1; float score_view1 = 0.0f; float score_view2 = 0.0f; bool matched_view1 = false; @@ -131,6 +143,38 @@ struct MergeDebug std::vector> group_proposal_indices; }; +enum class AssociationStatus +{ + Matched, + New, + Ambiguous, +}; + +struct AssociationReport +{ + std::vector pose_previous_indices; + std::vector pose_previous_track_ids; + std::vector pose_status; + std::vector> pose_candidate_previous_indices; + std::vector> pose_candidate_previous_track_ids; + std::vector unmatched_previous_indices; + std::vector unmatched_previous_track_ids; + std::vector new_pose_indices; + std::vector ambiguous_pose_indices; +}; + +struct FinalPoseAssociationDebug +{ + int final_pose_index = -1; + std::vector source_core_proposal_indices; + std::vector source_pair_indices; + std::vector candidate_previous_indices; + std::vector candidate_previous_track_ids; + int resolved_previous_index = -1; + int64_t resolved_previous_track_id = -1; + AssociationStatus status = AssociationStatus::New; +}; + struct TriangulationTrace { std::vector pairs; @@ -139,9 +183,17 @@ struct TriangulationTrace GroupingDebug grouping; std::vector full_proposals; MergeDebug merge; + AssociationReport association; + std::vector final_pose_associations; PoseBatch3D final_poses; }; +struct TriangulationResult +{ + PoseBatch3D poses; + AssociationReport association; +}; + // ================================================================================================= struct TriangulationOptions @@ -165,40 +217,15 @@ std::vector build_pair_candidates(const PoseBatch2DView &poses_2d PreviousPoseFilterDebug filter_pairs_with_previous_poses( const PoseBatch2DView &poses_2d, const TriangulationConfig &config, - const PoseBatch3DView &previous_poses_3d, + const TrackedPoseBatch3DView &previous_poses_3d, const TriangulationOptions *options_override = nullptr); -inline PreviousPoseFilterDebug filter_pairs_with_previous_poses( - const PoseBatch2D &poses_2d, - const TriangulationConfig &config, - const PoseBatch3D &previous_poses_3d, - const TriangulationOptions *options_override = nullptr) -{ - return filter_pairs_with_previous_poses(poses_2d.view(), config, previous_poses_3d.view(), options_override); -} - TriangulationTrace triangulate_debug( const PoseBatch2DView &poses_2d, const TriangulationConfig &config, - const PoseBatch3DView *previous_poses_3d = nullptr, + const TrackedPoseBatch3DView *previous_poses_3d = nullptr, const TriangulationOptions *options_override = nullptr); -inline TriangulationTrace triangulate_debug( - const PoseBatch2D &poses_2d, - const TriangulationConfig &config, - const PoseBatch3D *previous_poses_3d = nullptr, - const TriangulationOptions *options_override = nullptr) -{ - PoseBatch3DView previous_view_storage; - const PoseBatch3DView *previous_view = nullptr; - if (previous_poses_3d != nullptr) - { - previous_view_storage = previous_poses_3d->view(); - previous_view = &previous_view_storage; - } - return triangulate_debug(poses_2d.view(), config, previous_view, options_override); -} - // ================================================================================================= /** @@ -213,21 +240,35 @@ inline TriangulationTrace triangulate_debug( PoseBatch3D triangulate_poses( const PoseBatch2DView &poses_2d, const TriangulationConfig &config, - const PoseBatch3DView *previous_poses_3d = nullptr, + const TriangulationOptions *options_override = nullptr); + +TriangulationResult triangulate_with_report( + const PoseBatch2DView &poses_2d, + const TriangulationConfig &config, + const TrackedPoseBatch3DView &previous_poses_3d, const TriangulationOptions *options_override = nullptr); inline PoseBatch3D triangulate_poses( const PoseBatch2D &poses_2d, const TriangulationConfig &config, - const PoseBatch3D *previous_poses_3d = nullptr, const TriangulationOptions *options_override = nullptr) { - PoseBatch3DView previous_view_storage; - const PoseBatch3DView *previous_view = nullptr; - if (previous_poses_3d != nullptr) - { - previous_view_storage = previous_poses_3d->view(); - previous_view = &previous_view_storage; - } - return triangulate_poses(poses_2d.view(), config, previous_view, options_override); + return triangulate_poses(poses_2d.view(), config, options_override); +} + +inline TriangulationTrace triangulate_debug( + const PoseBatch2D &poses_2d, + const TriangulationConfig &config, + const TriangulationOptions *options_override = nullptr) +{ + return triangulate_debug(poses_2d.view(), config, nullptr, options_override); +} + +inline TriangulationResult triangulate_with_report( + const PoseBatch2D &poses_2d, + const TriangulationConfig &config, + const TrackedPoseBatch3DView &previous_poses_3d, + const TriangulationOptions *options_override = nullptr) +{ + return triangulate_with_report(poses_2d.view(), config, previous_poses_3d, options_override); } diff --git a/rpt/triangulator.cpp b/rpt/triangulator.cpp index 077bd8c..6ce9962 100644 --- a/rpt/triangulator.cpp +++ b/rpt/triangulator.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "interface.hpp" @@ -190,6 +191,15 @@ struct PreviousProjectionCache std::vector core_poses; }; +struct FinalAssociationState +{ + std::vector candidate_previous_indices; + std::vector candidate_previous_track_ids; + int resolved_previous_index = -1; + int64_t resolved_previous_track_id = -1; + AssociationStatus status = AssociationStatus::New; +}; + constexpr std::array kCoreJoints = { "shoulder_left", "shoulder_right", @@ -218,7 +228,7 @@ constexpr std::array, 8> kCoreLimb std::vector build_pair_candidates_from_packed(const PackedPoseStore2D &packed_poses); PreviousProjectionCache project_previous_poses( - const PoseBatch3DView &previous_poses_3d, + const TrackedPoseBatch3DView &previous_poses_3d, const std::vector &internal_cameras, const std::vector &core_joint_idx); PreviousPoseFilterDebug filter_pairs_with_previous_poses_impl( @@ -227,7 +237,7 @@ PreviousPoseFilterDebug filter_pairs_with_previous_poses_impl( const std::vector &core_joint_idx, const std::vector &pairs, const TriangulationOptions &options, - const PoseBatch3DView &previous_poses_3d); + const TrackedPoseBatch3DView &previous_poses_3d); float calc_pose_score( const Pose2D &pose, const Pose2D &reference_pose, @@ -279,7 +289,7 @@ TriangulationTrace triangulate_debug_impl( const std::vector &cameras, const std::array, 2> &roomparams, const std::vector &joint_names, - const PoseBatch3DView *previous_poses_3d, + const TrackedPoseBatch3DView *previous_poses_3d, const TriangulationOptions &options); PreparedInputs prepare_inputs( @@ -346,6 +356,39 @@ PreparedInputs prepare_inputs( std::move(packed_poses)); } +void validate_previous_tracks( + const TrackedPoseBatch3DView &previous_poses_3d, + const std::vector &joint_names) +{ + if (previous_poses_3d.num_persons == 0) + { + return; + } + if (previous_poses_3d.track_ids == nullptr) + { + throw std::invalid_argument("previous_track_ids must not be null."); + } + if (previous_poses_3d.data == nullptr) + { + throw std::invalid_argument("previous_poses_3d data must not be null."); + } + if (previous_poses_3d.num_joints != joint_names.size()) + { + throw std::invalid_argument("previous_poses_3d must use the same joint count as joint_names."); + } + + std::unordered_set seen_track_ids; + seen_track_ids.reserve(previous_poses_3d.num_persons); + for (size_t person = 0; person < previous_poses_3d.num_persons; ++person) + { + const int64_t track_id = previous_poses_3d.track_id(person); + if (!seen_track_ids.insert(track_id).second) + { + throw std::invalid_argument("previous_track_ids must be unique."); + } + } +} + std::vector build_pair_candidates_from_packed(const PackedPoseStore2D &packed_poses) { std::vector pairs; @@ -373,7 +416,7 @@ std::vector build_pair_candidates_from_packed(const PackedPoseSto } PreviousProjectionCache project_previous_poses( - const PoseBatch3DView &previous_poses_3d, + const TrackedPoseBatch3DView &previous_poses_3d, const std::vector &internal_cameras, const std::vector &core_joint_idx) { @@ -449,7 +492,7 @@ PreviousPoseFilterDebug filter_pairs_with_previous_poses_impl( const std::vector &core_joint_idx, const std::vector &pairs, const TriangulationOptions &options, - const PoseBatch3DView &previous_poses_3d) + const TrackedPoseBatch3DView &previous_poses_3d) { PreviousPoseFilterDebug debug; debug.used_previous_poses = true; @@ -500,6 +543,7 @@ PreviousPoseFilterDebug filter_pairs_with_previous_poses_impl( { best_match = PreviousPoseMatch { static_cast(previous_index), + previous_poses_3d.track_id(previous_index), score_view1, score_view2, true, @@ -515,6 +559,7 @@ PreviousPoseFilterDebug filter_pairs_with_previous_poses_impl( { best_match = PreviousPoseMatch { static_cast(previous_index), + previous_poses_3d.track_id(previous_index), score_view1, score_view2, matched_view1, @@ -549,16 +594,15 @@ TriangulationTrace triangulate_debug_impl( const std::vector &cameras, const std::array, 2> &roomparams, const std::vector &joint_names, - const PoseBatch3DView *previous_poses_3d, + const TrackedPoseBatch3DView *previous_poses_3d, const TriangulationOptions &options) { TriangulationTrace trace; trace.final_poses.num_joints = joint_names.size(); - if (previous_poses_3d != nullptr && previous_poses_3d->num_persons > 0 && - previous_poses_3d->num_joints != joint_names.size()) + if (previous_poses_3d != nullptr) { - throw std::invalid_argument("previous_poses_3d must use the same joint count as joint_names."); + validate_previous_tracks(*previous_poses_3d, joint_names); } PreparedInputs prepared = prepare_inputs(poses_2d, cameras, joint_names); @@ -831,6 +875,63 @@ TriangulationTrace triangulate_debug_impl( trace.merge.merged_poses.push_back(final_poses_3d[i]); } + std::vector group_associations(groups.size()); + if (previous_poses_3d != nullptr) + { + for (size_t group_index = 0; group_index < trace.merge.group_proposal_indices.size(); ++group_index) + { + FinalAssociationState association; + std::set candidate_previous_indices; + std::set candidate_previous_track_ids; + + for (const int core_index : trace.merge.group_proposal_indices[group_index]) + { + if (core_index < 0 || static_cast(core_index) >= trace.core_proposals.size()) + { + continue; + } + + const int pair_index = trace.core_proposals[static_cast(core_index)].pair_index; + if (pair_index < 0 || static_cast(pair_index) >= trace.previous_filter.matches.size()) + { + continue; + } + + const PreviousPoseMatch &match = trace.previous_filter.matches[static_cast(pair_index)]; + if (!match.kept || match.previous_pose_index < 0 || match.previous_track_id < 0) + { + continue; + } + + candidate_previous_indices.insert(match.previous_pose_index); + candidate_previous_track_ids.insert(match.previous_track_id); + } + + association.candidate_previous_indices.assign( + candidate_previous_indices.begin(), candidate_previous_indices.end()); + association.candidate_previous_track_ids.assign( + candidate_previous_track_ids.begin(), candidate_previous_track_ids.end()); + + if (association.candidate_previous_track_ids.empty()) + { + association.status = AssociationStatus::New; + } + else if (association.candidate_previous_track_ids.size() == 1 && + association.candidate_previous_indices.size() == 1) + { + association.status = AssociationStatus::Matched; + association.resolved_previous_index = association.candidate_previous_indices.front(); + association.resolved_previous_track_id = association.candidate_previous_track_ids.front(); + } + else + { + association.status = AssociationStatus::Ambiguous; + } + + group_associations[group_index] = std::move(association); + } + } + add_extra_joints(final_poses_3d, joint_names); filter_poses(final_poses_3d, roomparams, prepared.core_joint_idx, prepared.core_limbs_idx); add_missing_joints(final_poses_3d, joint_names, options.min_match_score); @@ -854,10 +955,20 @@ TriangulationTrace triangulate_debug_impl( trace.final_poses.num_persons = valid_persons; trace.final_poses.data.resize(valid_persons * trace.final_poses.num_joints * 4); + trace.association.pose_previous_indices.reserve(valid_persons); + trace.association.pose_previous_track_ids.reserve(valid_persons); + trace.association.pose_status.reserve(valid_persons); + trace.association.pose_candidate_previous_indices.reserve(valid_persons); + trace.association.pose_candidate_previous_track_ids.reserve(valid_persons); + trace.final_pose_associations.reserve(valid_persons); + + std::set resolved_previous_indices; + std::set resolved_previous_track_ids; size_t person_index = 0; - for (const auto &pose : final_poses_3d) + for (size_t group_index = 0; group_index < final_poses_3d.size(); ++group_index) { + const auto &pose = final_poses_3d[group_index]; const bool is_valid = std::any_of( pose.begin(), pose.end(), @@ -877,9 +988,70 @@ TriangulationTrace triangulate_debug_impl( trace.final_poses.at(person_index, joint, coord) = pose[joint][coord]; } } + + if (previous_poses_3d != nullptr) + { + const FinalAssociationState &association = group_associations[group_index]; + trace.association.pose_previous_indices.push_back(association.resolved_previous_index); + trace.association.pose_previous_track_ids.push_back(association.resolved_previous_track_id); + trace.association.pose_status.push_back(association.status); + trace.association.pose_candidate_previous_indices.push_back(association.candidate_previous_indices); + trace.association.pose_candidate_previous_track_ids.push_back( + association.candidate_previous_track_ids); + + if (association.status == AssociationStatus::Matched) + { + resolved_previous_indices.insert(association.resolved_previous_index); + resolved_previous_track_ids.insert(association.resolved_previous_track_id); + } + else if (association.status == AssociationStatus::New) + { + trace.association.new_pose_indices.push_back(static_cast(person_index)); + } + else if (association.status == AssociationStatus::Ambiguous) + { + trace.association.ambiguous_pose_indices.push_back(static_cast(person_index)); + } + + FinalPoseAssociationDebug debug_association; + debug_association.final_pose_index = static_cast(person_index); + debug_association.source_core_proposal_indices = + trace.merge.group_proposal_indices[group_index]; + debug_association.candidate_previous_indices = association.candidate_previous_indices; + debug_association.candidate_previous_track_ids = association.candidate_previous_track_ids; + debug_association.resolved_previous_index = association.resolved_previous_index; + debug_association.resolved_previous_track_id = association.resolved_previous_track_id; + debug_association.status = association.status; + debug_association.source_pair_indices.reserve(debug_association.source_core_proposal_indices.size()); + for (const int core_index : debug_association.source_core_proposal_indices) + { + if (core_index >= 0 && static_cast(core_index) < trace.core_proposals.size()) + { + debug_association.source_pair_indices.push_back( + trace.core_proposals[static_cast(core_index)].pair_index); + } + } + trace.final_pose_associations.push_back(std::move(debug_association)); + } + ++person_index; } + if (previous_poses_3d != nullptr) + { + for (size_t previous_index = 0; previous_index < previous_poses_3d->num_persons; ++previous_index) + { + const int previous_index_int = static_cast(previous_index); + const int64_t track_id = previous_poses_3d->track_id(previous_index); + if (!resolved_previous_indices.contains(previous_index_int) && + !resolved_previous_track_ids.contains(track_id)) + { + trace.association.unmatched_previous_indices.push_back(previous_index_int); + trace.association.unmatched_previous_track_ids.push_back(track_id); + } + } + } + return trace; } @@ -887,13 +1059,10 @@ PreviousPoseFilterDebug filter_pairs_with_previous_poses_impl( const PoseBatch2DView &poses_2d, const std::vector &cameras, const std::vector &joint_names, - const PoseBatch3DView &previous_poses_3d, + const TrackedPoseBatch3DView &previous_poses_3d, const TriangulationOptions &options) { - if (previous_poses_3d.num_persons > 0 && previous_poses_3d.num_joints != joint_names.size()) - { - throw std::invalid_argument("previous_poses_3d must use the same joint count as joint_names."); - } + validate_previous_tracks(previous_poses_3d, joint_names); PreparedInputs prepared = prepare_inputs(poses_2d, cameras, joint_names); const std::vector pairs = build_pair_candidates_from_packed(prepared.packed_poses); @@ -2350,7 +2519,7 @@ std::vector build_pair_candidates(const PoseBatch2DView &poses_2d PreviousPoseFilterDebug filter_pairs_with_previous_poses( const PoseBatch2DView &poses_2d, const TriangulationConfig &config, - const PoseBatch3DView &previous_poses_3d, + const TrackedPoseBatch3DView &previous_poses_3d, const TriangulationOptions *options_override) { const TriangulationOptions &options = @@ -2362,7 +2531,7 @@ PreviousPoseFilterDebug filter_pairs_with_previous_poses( TriangulationTrace triangulate_debug( const PoseBatch2DView &poses_2d, const TriangulationConfig &config, - const PoseBatch3DView *previous_poses_3d, + const TrackedPoseBatch3DView *previous_poses_3d, const TriangulationOptions *options_override) { const TriangulationOptions &options = @@ -2374,8 +2543,20 @@ TriangulationTrace triangulate_debug( PoseBatch3D triangulate_poses( const PoseBatch2DView &poses_2d, const TriangulationConfig &config, - const PoseBatch3DView *previous_poses_3d, const TriangulationOptions *options_override) { - return triangulate_debug(poses_2d, config, previous_poses_3d, options_override).final_poses; + return triangulate_debug(poses_2d, config, nullptr, options_override).final_poses; +} + +TriangulationResult triangulate_with_report( + const PoseBatch2DView &poses_2d, + const TriangulationConfig &config, + const TrackedPoseBatch3DView &previous_poses_3d, + const TriangulationOptions *options_override) +{ + TriangulationTrace trace = triangulate_debug(poses_2d, config, &previous_poses_3d, options_override); + return TriangulationResult { + std::move(trace.final_poses), + std::move(trace.association), + }; } diff --git a/src/rpt/__init__.py b/src/rpt/__init__.py index f8560bb..92c4168 100644 --- a/src/rpt/__init__.py +++ b/src/rpt/__init__.py @@ -4,10 +4,14 @@ from collections.abc import Sequence from typing import TYPE_CHECKING from ._core import ( + AssociationReport, + AssociationStatus, Camera, CameraModel, + FinalPoseAssociationDebug, TriangulationConfig, TriangulationOptions, + TriangulationResult, CoreProposalDebug, FullProposalDebug, GroupingDebug, @@ -22,6 +26,7 @@ from ._core import ( make_camera, triangulate_debug, triangulate_poses, + triangulate_with_report, ) if TYPE_CHECKING: @@ -67,8 +72,12 @@ def make_triangulation_config( __all__ = [ "Camera", "CameraModel", + "AssociationReport", + "AssociationStatus", + "FinalPoseAssociationDebug", "TriangulationConfig", "TriangulationOptions", + "TriangulationResult", "CoreProposalDebug", "FullProposalDebug", "GroupingDebug", @@ -86,4 +95,5 @@ __all__ = [ "pack_poses_2d", "triangulate_debug", "triangulate_poses", + "triangulate_with_report", ] diff --git a/tests/test_interface.py b/tests/test_interface.py index 27c0e9a..393be20 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -119,11 +119,19 @@ def test_triangulate_accepts_empty_previous_poses(): poses_2d, person_counts, cameras = load_case("data/p1/sample.json", "tests/poses_p1.json") config = make_config(cameras, [[5.6, 6.4, 2.4], [0.0, -0.5, 1.2]]) empty_previous = np.zeros((0, len(JOINT_NAMES), 4), dtype=np.float32) + empty_previous_ids = np.zeros((0,), dtype=np.int64) baseline = rpt.triangulate_poses(poses_2d, person_counts, config) - with_previous = rpt.triangulate_poses(poses_2d, person_counts, config, empty_previous) + result = rpt.triangulate_with_report( + poses_2d, + person_counts, + config, + empty_previous, + empty_previous_ids, + ) - np.testing.assert_allclose(with_previous, baseline, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(result.poses_3d, baseline, rtol=1e-5, atol=1e-5) + assert result.association.unmatched_previous_track_ids == [] def test_triangulate_debug_matches_final_output(): @@ -135,6 +143,7 @@ def test_triangulate_debug_matches_final_output(): np.testing.assert_allclose(trace.final_poses, final_poses, rtol=1e-5, atol=1e-5) assert len(trace.pairs) >= len(trace.core_proposals) + assert trace.association.pose_previous_track_ids == [] for group in trace.grouping.groups: assert all(0 <= index < len(trace.core_proposals) for index in group.proposal_indices) for merge_indices in trace.merge.group_proposal_indices: @@ -145,18 +154,67 @@ def test_filter_pairs_with_previous_poses_returns_debug_matches(): poses_2d, person_counts, cameras = load_case("data/p1/sample.json", "tests/poses_p1.json") config = make_config(cameras, [[5.6, 6.4, 2.4], [0.0, -0.5, 1.2]]) previous_poses = rpt.triangulate_poses(poses_2d, person_counts, config) + previous_track_ids = np.arange(previous_poses.shape[0], dtype=np.int64) + 100 debug = rpt.filter_pairs_with_previous_poses( poses_2d, person_counts, config, previous_poses, + previous_track_ids, ) assert debug.used_previous_poses is True assert len(debug.matches) == len(rpt.build_pair_candidates(poses_2d, person_counts)) assert len(debug.kept_pairs) == len(debug.kept_pair_indices) assert any(match.matched_view1 or match.matched_view2 for match in debug.matches) + assert any(match.previous_track_id >= 100 for match in debug.matches if match.previous_pose_index >= 0) + + +def test_triangulate_with_report_resolves_previous_track_ids(): + poses_2d, person_counts, cameras = load_case("data/p1/sample.json", "tests/poses_p1.json") + config = make_config(cameras, [[5.6, 6.4, 2.4], [0.0, -0.5, 1.2]]) + previous_poses = rpt.triangulate_poses(poses_2d, person_counts, config) + previous_track_ids = np.arange(previous_poses.shape[0], dtype=np.int64) + 100 + + result = rpt.triangulate_with_report( + poses_2d, + person_counts, + config, + previous_poses, + previous_track_ids, + ) + + assert result.poses_3d.shape == previous_poses.shape + assert len(result.association.pose_previous_track_ids) == result.poses_3d.shape[0] + matched_track_ids = sorted( + track_id for track_id in result.association.pose_previous_track_ids if track_id >= 0 + ) + unmatched_track_ids = sorted(result.association.unmatched_previous_track_ids) + + for pose_index in result.association.new_pose_indices: + assert result.association.pose_previous_track_ids[pose_index] == -1 + for pose_index in result.association.ambiguous_pose_indices: + assert result.association.pose_previous_track_ids[pose_index] == -1 + + assert matched_track_ids == sorted(set(matched_track_ids)) + assert sorted(matched_track_ids + unmatched_track_ids) == list(previous_track_ids) + + +def test_triangulate_with_report_rejects_duplicate_previous_track_ids(): + poses_2d, person_counts, cameras = load_case("data/p1/sample.json", "tests/poses_p1.json") + config = make_config(cameras, [[5.6, 6.4, 2.4], [0.0, -0.5, 1.2]]) + previous_poses = rpt.triangulate_poses(poses_2d, person_counts, config) + duplicate_ids = np.zeros((previous_poses.shape[0],), dtype=np.int64) + + with pytest.raises(ValueError, match="unique"): + rpt.triangulate_with_report( + poses_2d, + person_counts, + config, + previous_poses, + duplicate_ids, + ) def test_triangulate_does_not_mutate_inputs():