Migrate Python bindings from SWIG to nanobind
This commit is contained in:
@@ -0,0 +1,43 @@
|
||||
find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
|
||||
|
||||
if(NOT nanobind_DIR)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -c "import nanobind; print(nanobind.cmake_dir())"
|
||||
OUTPUT_VARIABLE nanobind_DIR
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
RESULT_VARIABLE nanobind_dir_result
|
||||
)
|
||||
if(NOT nanobind_dir_result EQUAL 0 OR NOT nanobind_DIR)
|
||||
message(FATAL_ERROR "Failed to resolve nanobind CMake directory from Python.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
set(RPT_PYTHON_PACKAGE_DIR "${CMAKE_CURRENT_BINARY_DIR}/rpt")
|
||||
file(MAKE_DIRECTORY "${RPT_PYTHON_PACKAGE_DIR}")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/src/rpt/__init__.py" "${RPT_PYTHON_PACKAGE_DIR}/__init__.py" COPYONLY)
|
||||
configure_file("${PROJECT_SOURCE_DIR}/src/rpt/_helpers.py" "${RPT_PYTHON_PACKAGE_DIR}/_helpers.py" COPYONLY)
|
||||
configure_file("${PROJECT_SOURCE_DIR}/src/rpt/py.typed" "${RPT_PYTHON_PACKAGE_DIR}/py.typed" COPYONLY)
|
||||
|
||||
nanobind_add_module(rpt_core_ext "${CMAKE_CURRENT_SOURCE_DIR}/rpt_module.cpp")
|
||||
|
||||
set_target_properties(rpt_core_ext PROPERTIES
|
||||
OUTPUT_NAME "_core"
|
||||
LIBRARY_OUTPUT_DIRECTORY "${RPT_PYTHON_PACKAGE_DIR}"
|
||||
)
|
||||
|
||||
target_link_libraries(rpt_core_ext PRIVATE rpt_core)
|
||||
target_include_directories(rpt_core_ext PRIVATE
|
||||
"${PROJECT_SOURCE_DIR}/rpt"
|
||||
)
|
||||
|
||||
nanobind_add_stub(rpt_core_stub
|
||||
MODULE rpt._core
|
||||
OUTPUT "${RPT_PYTHON_PACKAGE_DIR}/_core.pyi"
|
||||
PYTHON_PATH "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
DEPENDS rpt_core_ext
|
||||
)
|
||||
|
||||
install(TARGETS rpt_core_ext LIBRARY DESTINATION rpt)
|
||||
install(FILES "${RPT_PYTHON_PACKAGE_DIR}/_core.pyi" DESTINATION rpt)
|
||||
@@ -0,0 +1,121 @@
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/ndarray.h>
|
||||
#include <nanobind/stl/array.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "interface.hpp"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
namespace
|
||||
{
|
||||
using PoseArray2D =
|
||||
nb::ndarray<nb::numpy, const float, nb::shape<-1, -1, -1, 3>, nb::c_contig>;
|
||||
using CountArray = nb::ndarray<nb::numpy, const uint32_t, nb::shape<-1>, nb::c_contig>;
|
||||
using RoomArray = nb::ndarray<nb::numpy, const float, nb::shape<2, 3>, nb::c_contig>;
|
||||
using PoseArray3D = nb::ndarray<nb::numpy, float, nb::shape<-1, -1, 4>, nb::c_contig>;
|
||||
|
||||
PoseBatch2D pose_batch_from_numpy(const PoseArray2D &poses_2d, const CountArray &person_counts)
|
||||
{
|
||||
if (poses_2d.shape(0) != person_counts.shape(0))
|
||||
{
|
||||
throw std::invalid_argument("poses_2d and person_counts must have the same number of views.");
|
||||
}
|
||||
|
||||
PoseBatch2D batch;
|
||||
batch.num_views = static_cast<size_t>(poses_2d.shape(0));
|
||||
batch.max_persons = static_cast<size_t>(poses_2d.shape(1));
|
||||
batch.num_joints = static_cast<size_t>(poses_2d.shape(2));
|
||||
batch.person_counts.assign(person_counts.data(), person_counts.data() + batch.num_views);
|
||||
|
||||
for (size_t i = 0; i < batch.person_counts.size(); ++i)
|
||||
{
|
||||
if (batch.person_counts[i] > batch.max_persons)
|
||||
{
|
||||
throw std::invalid_argument("person_counts entries must not exceed the padded person dimension.");
|
||||
}
|
||||
}
|
||||
|
||||
const size_t total_size = batch.num_views * batch.max_persons * batch.num_joints * 3;
|
||||
batch.data.resize(total_size);
|
||||
std::memcpy(batch.data.data(), poses_2d.data(), total_size * sizeof(float));
|
||||
return batch;
|
||||
}
|
||||
|
||||
std::array<std::array<float, 3>, 2> roomparams_from_numpy(const RoomArray &roomparams)
|
||||
{
|
||||
std::array<std::array<float, 3>, 2> result {};
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
{
|
||||
for (size_t j = 0; j < 3; ++j)
|
||||
{
|
||||
result[i][j] = roomparams(i, j);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
PoseArray3D pose_batch_to_numpy(PoseBatch3D batch)
|
||||
{
|
||||
auto *storage = new std::vector<float>(std::move(batch.data));
|
||||
nb::capsule owner(storage, [](void *value) noexcept
|
||||
{
|
||||
delete static_cast<std::vector<float> *>(value);
|
||||
});
|
||||
|
||||
const size_t shape[3] = {batch.num_persons, batch.num_joints, 4};
|
||||
return PoseArray3D(storage->data(), 3, shape, owner);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
NB_MODULE(_core, m)
|
||||
{
|
||||
nb::class_<Camera>(m, "Camera")
|
||||
.def(nb::init<>())
|
||||
.def_rw("name", &Camera::name)
|
||||
.def_rw("K", &Camera::K)
|
||||
.def_rw("DC", &Camera::DC)
|
||||
.def_rw("R", &Camera::R)
|
||||
.def_rw("T", &Camera::T)
|
||||
.def_rw("width", &Camera::width)
|
||||
.def_rw("height", &Camera::height)
|
||||
.def_rw("type", &Camera::type)
|
||||
.def("__repr__", [](const Camera &camera)
|
||||
{
|
||||
return camera.to_string();
|
||||
});
|
||||
|
||||
nb::class_<Triangulator>(m, "Triangulator")
|
||||
.def(nb::init<float, size_t>(),
|
||||
"min_match_score"_a = 0.95f,
|
||||
"min_group_size"_a = 1)
|
||||
.def(
|
||||
"triangulate_poses",
|
||||
[](Triangulator &self,
|
||||
const PoseArray2D &poses_2d,
|
||||
const CountArray &person_counts,
|
||||
const std::vector<Camera> &cameras,
|
||||
const RoomArray &roomparams,
|
||||
const std::vector<std::string> &joint_names)
|
||||
{
|
||||
PoseBatch2D pose_batch = pose_batch_from_numpy(poses_2d, person_counts);
|
||||
auto room = roomparams_from_numpy(roomparams);
|
||||
PoseBatch3D poses_3d = self.triangulate_poses(pose_batch, cameras, room, joint_names);
|
||||
return pose_batch_to_numpy(std::move(poses_3d));
|
||||
},
|
||||
"poses_2d"_a,
|
||||
"person_counts"_a,
|
||||
"cameras"_a,
|
||||
"roomparams"_a,
|
||||
"joint_names"_a)
|
||||
.def("reset", &Triangulator::reset)
|
||||
.def("print_stats", &Triangulator::print_stats);
|
||||
}
|
||||
Reference in New Issue
Block a user