455 lines
16 KiB
C++
455 lines
16 KiB
C++
#include <algorithm>
#include <array>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/array.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include "interface.hpp"
|
|
|
|
namespace nb = nanobind;
|
|
using namespace nb::literals;
|
|
|
|
namespace
|
|
{
|
|
using PoseArray2D =
|
|
nb::ndarray<nb::numpy, const float, nb::shape<-1, -1, -1, 3>, nb::c_contig>;
|
|
using CountArray = nb::ndarray<nb::numpy, const uint32_t, nb::shape<-1>, nb::c_contig>;
|
|
using TrackIdArray = nb::ndarray<nb::numpy, const int64_t, nb::shape<-1>, nb::c_contig>;
|
|
using PoseArray3DConst =
|
|
nb::ndarray<nb::numpy, const float, nb::shape<-1, -1, 4>, nb::c_contig>;
|
|
using PoseArray3D = nb::ndarray<nb::numpy, float, nb::shape<-1, -1, 4>, nb::c_contig>;
|
|
using PoseArray2DOut = nb::ndarray<nb::numpy, float, nb::shape<-1, 4>, nb::c_contig>;
|
|
|
|
PoseBatch2DView pose_batch_view_from_numpy(const PoseArray2D &poses_2d, const CountArray &person_counts)
|
|
{
|
|
if (poses_2d.shape(0) != person_counts.shape(0))
|
|
{
|
|
throw std::invalid_argument("poses_2d and person_counts must have the same number of views.");
|
|
}
|
|
|
|
for (size_t i = 0; i < static_cast<size_t>(person_counts.shape(0)); ++i)
|
|
{
|
|
if (person_counts(i) > poses_2d.shape(1))
|
|
{
|
|
throw std::invalid_argument("person_counts entries must not exceed the padded person dimension.");
|
|
}
|
|
}
|
|
|
|
return PoseBatch2DView {
|
|
poses_2d.data(),
|
|
person_counts.data(),
|
|
static_cast<size_t>(poses_2d.shape(0)),
|
|
static_cast<size_t>(poses_2d.shape(1)),
|
|
static_cast<size_t>(poses_2d.shape(2)),
|
|
};
|
|
}
|
|
|
|
PoseBatch3DView pose_batch3d_view_from_numpy(const PoseArray3DConst &poses_3d)
|
|
{
|
|
return PoseBatch3DView {
|
|
poses_3d.data(),
|
|
static_cast<size_t>(poses_3d.shape(0)),
|
|
static_cast<size_t>(poses_3d.shape(1)),
|
|
};
|
|
}
|
|
|
|
TrackedPoseBatch3DView tracked_pose_batch_view_from_numpy(
|
|
const PoseArray3DConst &poses_3d,
|
|
const TrackIdArray &track_ids)
|
|
{
|
|
if (poses_3d.shape(0) != track_ids.shape(0))
|
|
{
|
|
throw std::invalid_argument(
|
|
"previous_poses_3d and previous_track_ids must have the same number of tracks.");
|
|
}
|
|
|
|
return TrackedPoseBatch3DView {
|
|
track_ids.data(),
|
|
poses_3d.data(),
|
|
static_cast<size_t>(poses_3d.shape(0)),
|
|
static_cast<size_t>(poses_3d.shape(1)),
|
|
};
|
|
}
|
|
|
|
PoseArray3D pose_batch_to_numpy(PoseBatch3D batch)
|
|
{
|
|
auto *storage = new std::vector<float>(std::move(batch.data));
|
|
nb::capsule owner(storage, [](void *value) noexcept
|
|
{
|
|
delete static_cast<std::vector<float> *>(value);
|
|
});
|
|
|
|
const size_t shape[3] = {batch.num_persons, batch.num_joints, 4};
|
|
return PoseArray3D(storage->data(), 3, shape, owner);
|
|
}
|
|
|
|
PoseArray3D pose_batch_to_numpy_copy(const PoseBatch3D &batch)
|
|
{
|
|
PoseBatch3D copy = batch;
|
|
return pose_batch_to_numpy(std::move(copy));
|
|
}
|
|
|
|
PoseArray2DOut pose_rows_to_numpy_copy(const std::vector<std::array<float, 4>> &rows)
|
|
{
|
|
auto *storage = new std::vector<float>(rows.size() * 4, 0.0f);
|
|
for (size_t row = 0; row < rows.size(); ++row)
|
|
{
|
|
for (size_t coord = 0; coord < 4; ++coord)
|
|
{
|
|
(*storage)[row * 4 + coord] = rows[row][coord];
|
|
}
|
|
}
|
|
|
|
nb::capsule owner(storage, [](void *value) noexcept
|
|
{
|
|
delete static_cast<std::vector<float> *>(value);
|
|
});
|
|
|
|
const size_t shape[2] = {rows.size(), 4};
|
|
return PoseArray2DOut(storage->data(), 2, shape, owner);
|
|
}
|
|
|
|
PoseArray3D merged_poses_to_numpy_copy(const std::vector<std::vector<std::array<float, 4>>> &poses)
|
|
{
|
|
size_t num_poses = poses.size();
|
|
size_t num_joints = 0;
|
|
if (!poses.empty())
|
|
{
|
|
num_joints = poses[0].size();
|
|
}
|
|
|
|
auto *storage = new std::vector<float>(num_poses * num_joints * 4, 0.0f);
|
|
for (size_t pose = 0; pose < num_poses; ++pose)
|
|
{
|
|
if (poses[pose].size() != num_joints)
|
|
{
|
|
delete storage;
|
|
throw std::invalid_argument("Merged poses must use a consistent joint count.");
|
|
}
|
|
for (size_t joint = 0; joint < num_joints; ++joint)
|
|
{
|
|
for (size_t coord = 0; coord < 4; ++coord)
|
|
{
|
|
(*storage)[((pose * num_joints) + joint) * 4 + coord] = poses[pose][joint][coord];
|
|
}
|
|
}
|
|
}
|
|
|
|
nb::capsule owner(storage, [](void *value) noexcept
|
|
{
|
|
delete static_cast<std::vector<float> *>(value);
|
|
});
|
|
|
|
const size_t shape[3] = {num_poses, num_joints, 4};
|
|
return PoseArray3D(storage->data(), 3, shape, owner);
|
|
}
|
|
} // namespace
|
|
|
|
NB_MODULE(_core, m)
{
    // ---- Cameras -----------------------------------------------------------

    nb::enum_<CameraModel>(m, "CameraModel")
        .value("PINHOLE", CameraModel::Pinhole)
        .value("FISHEYE", CameraModel::Fisheye);

    // No Python constructor is bound for Camera; make_camera() below is the factory.
    // Each property returns a copy of the corresponding C++ field.
    nb::class_<Camera>(m, "Camera")
        .def_prop_ro("name", [](const Camera &camera) { return camera.name; })
        .def_prop_ro("K", [](const Camera &camera) { return camera.K; })
        .def_prop_ro("DC", [](const Camera &camera) { return camera.DC; })
        .def_prop_ro("R", [](const Camera &camera) { return camera.R; })
        .def_prop_ro("T", [](const Camera &camera) { return camera.T; })
        .def_prop_ro("width", [](const Camera &camera) { return camera.width; })
        .def_prop_ro("height", [](const Camera &camera) { return camera.height; })
        .def_prop_ro("model", [](const Camera &camera) { return camera.model; })
        .def_prop_ro("invR", [](const Camera &camera) { return camera.invR; })
        .def_prop_ro("center", [](const Camera &camera) { return camera.center; })
        .def_prop_ro("newK", [](const Camera &camera) { return camera.newK; })
        .def_prop_ro("invK", [](const Camera &camera) { return camera.invK; })
        .def("__repr__", [](const Camera &camera) { return camera.to_string(); });

    // ---- Configuration -----------------------------------------------------

    nb::class_<TriangulationOptions>(m, "TriangulationOptions")
        .def(nb::init<>())
        .def_rw("min_match_score", &TriangulationOptions::min_match_score)
        .def_rw("min_group_size", &TriangulationOptions::min_group_size);

    nb::class_<TriangulationConfig>(m, "TriangulationConfig")
        .def(nb::init<>())
        .def_rw("cameras", &TriangulationConfig::cameras)
        .def_rw("roomparams", &TriangulationConfig::roomparams)
        .def_rw("joint_names", &TriangulationConfig::joint_names)
        .def_rw("options", &TriangulationConfig::options);

    // ---- Debug / trace structures ------------------------------------------
    // Pose-valued properties (pose_3d, merged_poses, final_poses, poses_3d) build a
    // fresh NumPy copy on every access, because the C++ object keeps its own data.

    nb::class_<PairCandidate>(m, "PairCandidate")
        .def(nb::init<>())
        .def_rw("view1", &PairCandidate::view1)
        .def_rw("view2", &PairCandidate::view2)
        .def_rw("person1", &PairCandidate::person1)
        .def_rw("person2", &PairCandidate::person2)
        .def_rw("global_person1", &PairCandidate::global_person1)
        .def_rw("global_person2", &PairCandidate::global_person2);

    nb::class_<PreviousPoseMatch>(m, "PreviousPoseMatch")
        .def(nb::init<>())
        .def_rw("previous_pose_index", &PreviousPoseMatch::previous_pose_index)
        .def_rw("previous_track_id", &PreviousPoseMatch::previous_track_id)
        .def_rw("score_view1", &PreviousPoseMatch::score_view1)
        .def_rw("score_view2", &PreviousPoseMatch::score_view2)
        .def_rw("matched_view1", &PreviousPoseMatch::matched_view1)
        .def_rw("matched_view2", &PreviousPoseMatch::matched_view2)
        .def_rw("kept", &PreviousPoseMatch::kept)
        .def_rw("decision", &PreviousPoseMatch::decision);

    nb::class_<PreviousPoseFilterDebug>(m, "PreviousPoseFilterDebug")
        .def(nb::init<>())
        .def_rw("used_previous_poses", &PreviousPoseFilterDebug::used_previous_poses)
        .def_rw("matches", &PreviousPoseFilterDebug::matches)
        .def_rw("kept_pair_indices", &PreviousPoseFilterDebug::kept_pair_indices)
        .def_rw("kept_pairs", &PreviousPoseFilterDebug::kept_pairs);

    nb::class_<CoreProposalDebug>(m, "CoreProposalDebug")
        .def(nb::init<>())
        .def_rw("pair_index", &CoreProposalDebug::pair_index)
        .def_rw("pair", &CoreProposalDebug::pair)
        .def_rw("score", &CoreProposalDebug::score)
        .def_rw("kept", &CoreProposalDebug::kept)
        .def_rw("drop_reason", &CoreProposalDebug::drop_reason)
        .def_prop_ro("pose_3d", [](const CoreProposalDebug &proposal)
        {
            return pose_rows_to_numpy_copy(proposal.pose_3d);
        }, nb::rv_policy::move);

    nb::class_<ProposalGroupDebug>(m, "ProposalGroupDebug")
        .def(nb::init<>())
        .def_rw("center", &ProposalGroupDebug::center)
        .def_rw("proposal_indices", &ProposalGroupDebug::proposal_indices)
        .def_prop_ro("pose_3d", [](const ProposalGroupDebug &group)
        {
            return pose_rows_to_numpy_copy(group.pose_3d);
        }, nb::rv_policy::move);

    nb::class_<GroupingDebug>(m, "GroupingDebug")
        .def(nb::init<>())
        .def_rw("initial_groups", &GroupingDebug::initial_groups)
        .def_rw("duplicate_pair_drops", &GroupingDebug::duplicate_pair_drops)
        .def_rw("groups", &GroupingDebug::groups);

    nb::class_<FullProposalDebug>(m, "FullProposalDebug")
        .def(nb::init<>())
        .def_rw("source_core_proposal_index", &FullProposalDebug::source_core_proposal_index)
        .def_rw("pair", &FullProposalDebug::pair)
        .def_prop_ro("pose_3d", [](const FullProposalDebug &proposal)
        {
            return pose_rows_to_numpy_copy(proposal.pose_3d);
        }, nb::rv_policy::move);

    nb::class_<MergeDebug>(m, "MergeDebug")
        .def(nb::init<>())
        .def_rw("group_proposal_indices", &MergeDebug::group_proposal_indices)
        .def_prop_ro("merged_poses", [](const MergeDebug &merge)
        {
            return merged_poses_to_numpy_copy(merge.merged_poses);
        }, nb::rv_policy::move);

    // ---- Track association -------------------------------------------------

    nb::enum_<AssociationStatus>(m, "AssociationStatus")
        .value("MATCHED", AssociationStatus::Matched)
        .value("NEW", AssociationStatus::New)
        .value("AMBIGUOUS", AssociationStatus::Ambiguous);

    nb::class_<AssociationReport>(m, "AssociationReport")
        .def(nb::init<>())
        .def_rw("pose_previous_indices", &AssociationReport::pose_previous_indices)
        .def_rw("pose_previous_track_ids", &AssociationReport::pose_previous_track_ids)
        .def_rw("pose_status", &AssociationReport::pose_status)
        .def_rw("pose_candidate_previous_indices", &AssociationReport::pose_candidate_previous_indices)
        .def_rw("pose_candidate_previous_track_ids", &AssociationReport::pose_candidate_previous_track_ids)
        .def_rw("unmatched_previous_indices", &AssociationReport::unmatched_previous_indices)
        .def_rw("unmatched_previous_track_ids", &AssociationReport::unmatched_previous_track_ids)
        .def_rw("new_pose_indices", &AssociationReport::new_pose_indices)
        .def_rw("ambiguous_pose_indices", &AssociationReport::ambiguous_pose_indices);

    nb::class_<FinalPoseAssociationDebug>(m, "FinalPoseAssociationDebug")
        .def(nb::init<>())
        .def_rw("final_pose_index", &FinalPoseAssociationDebug::final_pose_index)
        .def_rw("source_core_proposal_indices", &FinalPoseAssociationDebug::source_core_proposal_indices)
        .def_rw("source_pair_indices", &FinalPoseAssociationDebug::source_pair_indices)
        .def_rw("candidate_previous_indices", &FinalPoseAssociationDebug::candidate_previous_indices)
        .def_rw("candidate_previous_track_ids", &FinalPoseAssociationDebug::candidate_previous_track_ids)
        .def_rw("resolved_previous_index", &FinalPoseAssociationDebug::resolved_previous_index)
        .def_rw("resolved_previous_track_id", &FinalPoseAssociationDebug::resolved_previous_track_id)
        .def_rw("status", &FinalPoseAssociationDebug::status);

    // ---- Results -----------------------------------------------------------

    nb::class_<TriangulationTrace>(m, "TriangulationTrace")
        .def(nb::init<>())
        .def_rw("pairs", &TriangulationTrace::pairs)
        .def_rw("previous_filter", &TriangulationTrace::previous_filter)
        .def_rw("core_proposals", &TriangulationTrace::core_proposals)
        .def_rw("grouping", &TriangulationTrace::grouping)
        .def_rw("full_proposals", &TriangulationTrace::full_proposals)
        .def_rw("merge", &TriangulationTrace::merge)
        .def_rw("association", &TriangulationTrace::association)
        .def_rw("final_pose_associations", &TriangulationTrace::final_pose_associations)
        .def_prop_ro("final_poses", [](const TriangulationTrace &trace)
        {
            return pose_batch_to_numpy_copy(trace.final_poses);
        }, nb::rv_policy::move);

    nb::class_<TriangulationResult>(m, "TriangulationResult")
        .def(nb::init<>())
        .def_rw("association", &TriangulationResult::association)
        .def_prop_ro("poses_3d", [](const TriangulationResult &result)
        {
            return pose_batch_to_numpy_copy(result.poses);
        }, nb::rv_policy::move);

    // ---- Free functions ----------------------------------------------------

    m.def(
        "make_camera",
        &make_camera,
        "name"_a,
        "K"_a,
        "DC"_a,
        "R"_a,
        "T"_a,
        "width"_a,
        "height"_a,
        "model"_a);

    m.def(
        "build_pair_candidates",
        [](const PoseArray2D &poses_2d, const CountArray &person_counts)
        {
            return build_pair_candidates(pose_batch_view_from_numpy(poses_2d, person_counts));
        },
        "poses_2d"_a,
        "person_counts"_a);

    m.def(
        "filter_pairs_with_previous_poses",
        [](const PoseArray2D &poses_2d,
           const CountArray &person_counts,
           const TriangulationConfig &config,
           const PoseArray3DConst &previous_poses_3d,
           const TrackIdArray &previous_track_ids)
        {
            return filter_pairs_with_previous_poses(
                pose_batch_view_from_numpy(poses_2d, person_counts),
                config,
                tracked_pose_batch_view_from_numpy(previous_poses_3d, previous_track_ids));
        },
        "poses_2d"_a,
        "person_counts"_a,
        "config"_a,
        "previous_poses_3d"_a,
        "previous_track_ids"_a);

    // triangulate_debug has two overloads: without previous poses, and with them.
    m.def(
        "triangulate_debug",
        [](const PoseArray2D &poses_2d,
           const CountArray &person_counts,
           const TriangulationConfig &config)
        {
            return triangulate_debug(pose_batch_view_from_numpy(poses_2d, person_counts), config);
        },
        "poses_2d"_a,
        "person_counts"_a,
        "config"_a);

    m.def(
        "triangulate_debug",
        [](const PoseArray2D &poses_2d,
           const CountArray &person_counts,
           const TriangulationConfig &config,
           const PoseArray3DConst &previous_poses_3d,
           const TrackIdArray &previous_track_ids)
        {
            // The core takes the previous poses by pointer, so keep the view in a named local.
            const TrackedPoseBatch3DView previous_view =
                tracked_pose_batch_view_from_numpy(previous_poses_3d, previous_track_ids);
            return triangulate_debug(
                pose_batch_view_from_numpy(poses_2d, person_counts),
                config,
                &previous_view);
        },
        "poses_2d"_a,
        "person_counts"_a,
        "config"_a,
        "previous_poses_3d"_a,
        "previous_track_ids"_a);

    m.def(
        "triangulate_poses",
        [](const PoseArray2D &poses_2d,
           const CountArray &person_counts,
           const TriangulationConfig &config)
        {
            // Pass the result straight through, so its buffer is moved into the NumPy
            // array. Binding it to a const local first forced a full copy of the data.
            return pose_batch_to_numpy(
                triangulate_poses(pose_batch_view_from_numpy(poses_2d, person_counts), config));
        },
        "poses_2d"_a,
        "person_counts"_a,
        "config"_a);

    m.def(
        "triangulate_with_report",
        [](const PoseArray2D &poses_2d,
           const CountArray &person_counts,
           const TriangulationConfig &config,
           const PoseArray3DConst &previous_poses_3d,
           const TrackIdArray &previous_track_ids) -> TriangulationResult
        {
            return triangulate_with_report(
                pose_batch_view_from_numpy(poses_2d, person_counts),
                config,
                tracked_pose_batch_view_from_numpy(previous_poses_3d, previous_track_ids));
        },
        "poses_2d"_a,
        "person_counts"_a,
        "config"_a,
        "previous_poses_3d"_a,
        "previous_track_ids"_a);
}
|