diff --git a/src/testers/mcap_multi_record_tester.cpp b/src/testers/mcap_multi_record_tester.cpp index 2a0a175..dd1058f 100644 --- a/src/testers/mcap_multi_record_tester.cpp +++ b/src/testers/mcap_multi_record_tester.cpp @@ -13,7 +13,9 @@ #include #include #include +#include #include +#include #include namespace { @@ -31,6 +33,20 @@ constexpr int exit_code(const TesterExitCode code) { return static_cast(code); } +[[nodiscard]] +std::optional parse_compression(std::string_view raw) { + if (raw == "none") { + return cvmmap_streamer::McapCompression::None; + } + if (raw == "lz4") { + return cvmmap_streamer::McapCompression::Lz4; + } + if (raw == "zstd") { + return cvmmap_streamer::McapCompression::Zstd; + } + return std::nullopt; +} + } int main(int argc, char **argv) { @@ -43,13 +59,21 @@ int main(int argc, char **argv) { argc > 1 ? std::filesystem::path(argv[1]) : std::filesystem::temp_directory_path() / "cvmmap_streamer_multi_record_test.mcap"; + const auto compression = + argc > 2 + ? parse_compression(argv[2]).value_or(cvmmap_streamer::McapCompression::None) + : cvmmap_streamer::McapCompression::None; + if (argc > 2 && !parse_compression(argv[2])) { + spdlog::error("invalid compression '{}': expected none|lz4|zstd", argv[2]); + return exit_code(TesterExitCode::CreateError); + } if (output_path.has_parent_path()) { std::filesystem::create_directories(output_path.parent_path()); } auto sink = cvmmap_streamer::record::MultiMcapRecordSink::create( output_path.string(), - cvmmap_streamer::McapCompression::None); + compression); if (!sink) { spdlog::error("failed to create MCAP sink: {}", sink.error()); return exit_code(TesterExitCode::CreateError); diff --git a/third_party/mcap/include/mcap/writer.hpp b/third_party/mcap/include/mcap/writer.hpp index a32b054..416878d 100644 --- a/third_party/mcap/include/mcap/writer.hpp +++ b/third_party/mcap/include/mcap/writer.hpp @@ -312,6 +312,7 @@ public: private: std::vector uncompressedBuffer_; std::vector compressedBuffer_; + CompressionLevel compressionLevel_ = CompressionLevel::Default; ZSTD_CCtx_s* zstdContext_ = nullptr; }; #endif diff --git a/third_party/mcap/include/mcap/writer.inl b/third_party/mcap/include/mcap/writer.inl index d58cf24..86a629c 100644 --- a/third_party/mcap/include/mcap/writer.inl +++ b/third_party/mcap/include/mcap/writer.inl @@ -241,10 +241,13 @@ int ZStdCompressionLevel(CompressionLevel level) { // ZStdWriter ////////////////////////////////////////////////////////////////// -ZStdWriter::ZStdWriter(CompressionLevel compressionLevel, uint64_t chunkSize) { +ZStdWriter::ZStdWriter(CompressionLevel compressionLevel, uint64_t chunkSize) + : compressionLevel_(compressionLevel) { zstdContext_ = ZSTD_createCCtx(); - ZSTD_CCtx_setParameter(zstdContext_, ZSTD_c_compressionLevel, - internal::ZStdCompressionLevel(compressionLevel)); + if (zstdContext_ == nullptr) { + std::cerr << "ZSTD_createCCtx failed\n"; + std::abort(); + } uncompressedBuffer_.reserve(chunkSize); } @@ -259,15 +262,16 @@ void ZStdWriter::handleWrite(const std::byte* data, uint64_t size) { void ZStdWriter::end() { const auto dstCapacity = ZSTD_compressBound(uncompressedBuffer_.size()); compressedBuffer_.resize(dstCapacity); - const size_t dstSize = ZSTD_compress2(zstdContext_, compressedBuffer_.data(), dstCapacity, - uncompressedBuffer_.data(), uncompressedBuffer_.size()); + const size_t dstSize = + ZSTD_compressCCtx(zstdContext_, compressedBuffer_.data(), dstCapacity, + uncompressedBuffer_.data(), uncompressedBuffer_.size(), + internal::ZStdCompressionLevel(compressionLevel_)); if (ZSTD_isError(dstSize)) { const auto errCode = ZSTD_getErrorCode(dstSize); - std::cerr << "ZSTD_compress2 failed: " << ZSTD_getErrorName(dstSize) << " (" + std::cerr << "ZSTD_compressCCtx failed: " << ZSTD_getErrorName(dstSize) << " (" << ZSTD_getErrorString(errCode) << ")\n"; std::abort(); } - ZSTD_CCtx_reset(zstdContext_, ZSTD_reset_session_only); compressedBuffer_.resize(dstSize); }