#include #include #include #include "interface.hpp" // ================================================================================================= namespace { size_t pose2d_offset( size_t view, size_t person, size_t joint, size_t coord, size_t max_persons, size_t num_joints) { return ((((view * max_persons) + person) * num_joints) + joint) * 3 + coord; } size_t pose3d_offset(size_t person, size_t joint, size_t coord, size_t num_joints) { return (((person * num_joints) + joint) * 4) + coord; } } // namespace // ================================================================================================= // ================================================================================================= float &PoseBatch2D::at(size_t view, size_t person, size_t joint, size_t coord) { return data[pose2d_offset(view, person, joint, coord, max_persons, num_joints)]; } const float &PoseBatch2DView::at(size_t view, size_t person, size_t joint, size_t coord) const { return data[pose2d_offset(view, person, joint, coord, max_persons, num_joints)]; } const float &PoseBatch3DView::at(size_t person, size_t joint, size_t coord) const { return data[pose3d_offset(person, joint, coord, num_joints)]; } const float &PoseBatch2D::at(size_t view, size_t person, size_t joint, size_t coord) const { return data[pose2d_offset(view, person, joint, coord, max_persons, num_joints)]; } PoseBatch2DView PoseBatch2D::view() const { return PoseBatch2DView {data.data(), person_counts.data(), num_views, max_persons, num_joints}; } PoseBatch2D PoseBatch2D::from_nested(const RaggedPoses2D &poses_2d) { PoseBatch2D batch; batch.num_views = poses_2d.size(); for (const auto &view : poses_2d) { batch.max_persons = std::max(batch.max_persons, view.size()); if (!view.empty()) { if (batch.num_joints == 0) { batch.num_joints = view[0].size(); } else if (batch.num_joints != view[0].size()) { throw std::invalid_argument("All views must use the same joint count."); } for (const auto &person : view) { if (person.size() != batch.num_joints) { throw std::invalid_argument("All persons must use the same joint count."); } } } } batch.person_counts.resize(batch.num_views); batch.data.assign(batch.num_views * batch.max_persons * batch.num_joints * 3, 0.0f); for (size_t view = 0; view < batch.num_views; ++view) { batch.person_counts[view] = static_cast(poses_2d[view].size()); for (size_t person = 0; person < poses_2d[view].size(); ++person) { for (size_t joint = 0; joint < batch.num_joints; ++joint) { for (size_t coord = 0; coord < 3; ++coord) { batch.at(view, person, joint, coord) = poses_2d[view][person][joint][coord]; } } } } return batch; } // ================================================================================================= float &PoseBatch3D::at(size_t person, size_t joint, size_t coord) { return data[pose3d_offset(person, joint, coord, num_joints)]; } const float &PoseBatch3D::at(size_t person, size_t joint, size_t coord) const { return data[pose3d_offset(person, joint, coord, num_joints)]; } PoseBatch3DView PoseBatch3D::view() const { return PoseBatch3DView {data.data(), num_persons, num_joints}; } NestedPoses3D PoseBatch3D::to_nested() const { NestedPoses3D poses_3d(num_persons); for (size_t person = 0; person < num_persons; ++person) { poses_3d[person].resize(num_joints); for (size_t joint = 0; joint < num_joints; ++joint) { for (size_t coord = 0; coord < 4; ++coord) { poses_3d[person][joint][coord] = at(person, joint, coord); } } } return poses_3d; } PoseBatch3D PoseBatch3D::from_nested(const NestedPoses3D &poses_3d) { PoseBatch3D batch; batch.num_persons = poses_3d.size(); if (!poses_3d.empty()) { batch.num_joints = poses_3d[0].size(); } batch.data.resize(batch.num_persons * batch.num_joints * 4); for (size_t person = 0; person < batch.num_persons; ++person) { if (poses_3d[person].size() != batch.num_joints) { throw std::invalid_argument("All 3D poses must use the same joint count."); } for (size_t joint = 0; joint < batch.num_joints; ++joint) { for (size_t coord = 0; coord < 4; ++coord) { batch.at(person, joint, coord) = poses_3d[person][joint][coord]; } } } return batch; }