diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp
index f1a2d3b8f2f..35e58fec035 100644
--- a/examples/portable/executor_runner/executor_runner.cpp
+++ b/examples/portable/executor_runner/executor_runner.cpp
@@ -26,6 +26,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -45,6 +46,7 @@ DEFINE_bool(
     "True if the model_path passed is a file descriptor with the prefix \"fd:///\".");
 using executorch::extension::FileDataLoader;
+using executorch::extension::FileDescriptorDataLoader;
 using executorch::runtime::Error;
 using executorch::runtime::EValue;
 using executorch::runtime::HierarchicalAllocator;
@@ -56,6 +58,39 @@ using executorch::runtime::Program;
 using executorch::runtime::Result;
 using executorch::runtime::Span;
+/**
+ * Loads the program at `model_path`, using a FileDescriptorDataLoader when
+ * `is_fd_uri` is true and a FileDataLoader otherwise.
+ *
+ * The loader lives in function-local static storage rather than on the stack:
+ * Program::load() is handed a pointer to it, and per the Program::load()
+ * contract the Program keeps that pointer, so the loader has to outlive the
+ * Program that is returned from here. As a consequence the loader is created
+ * on the first call only; this helper is meant to be called once, from main().
+ *
+ * A loader that could not be created is rejected by ET_CHECK_MSG.
+ */
+static Result<Program> getProgram(
+    const bool is_fd_uri,
+    const char* model_path) {
+  // Create a loader to get the data of the program file. This demonstrates both
+  // FileDataLoader and FileDescriptorDataLoader. There are other DataLoaders
+  // that use mmap() or point to data that's already in memory, and users can
+  // create their own DataLoaders to load from arbitrary sources.
+  if (!is_fd_uri) {
+    static Result<FileDataLoader> loader = FileDataLoader::from(model_path);
+
+    ET_CHECK_MSG(
+        loader.ok(),
+        "FileDataLoader::from() failed: 0x%" PRIx32,
+        (uint32_t)loader.error());
+    return Program::load(&loader.get());
+  } else {
+    static Result<FileDescriptorDataLoader> loader =
+        FileDescriptorDataLoader::fromFileDescriptorUri(model_path);
+
+    ET_CHECK_MSG(
+        loader.ok(),
+        "FileDescriptorDataLoader::fromFileDescriptorUri() failed: 0x%" PRIx32,
+        (uint32_t)loader.error());
+    return Program::load(&loader.get());
+  }
+}
+
 int main(int argc, char** argv) {
   executorch::runtime::runtime_init();
@@ -75,18 +116,9 @@ int main(int argc, char** argv) {
   const char* model_path = FLAGS_model_path.c_str();
   const bool is_fd_uri = FLAGS_is_fd_uri;
-  Result<FileDataLoader> loader = is_fd_uri
-      ? FileDataLoader::fromFileDescriptorUri(model_path)
-      : FileDataLoader::from(model_path);
-
-  ET_CHECK_MSG(
-      loader.ok(),
-      "FileDataLoader::from() failed: 0x%" PRIx32,
-      (uint32_t)loader.error());
-
   // Parse the program file. This is immutable, and can also be reused between
   // multiple execution invocations across multiple threads.
-  Result<Program> program = Program::load(&loader.get());
+  Result<Program> program = getProgram(is_fd_uri, model_path);
   if (!program.ok()) {
     ET_LOG(Error, "Failed to parse model file %s", model_path);
     return 1;
diff --git a/examples/portable/executor_runner/targets.bzl b/examples/portable/executor_runner/targets.bzl
index 9cddaa4ed77..83c63d3a411 100644
--- a/examples/portable/executor_runner/targets.bzl
+++ b/examples/portable/executor_runner/targets.bzl
@@ -15,6 +15,7 @@ def define_common_targets():
         deps = [
             "//executorch/runtime/executor:program",
             "//executorch/extension/data_loader:file_data_loader",
+            "//executorch/extension/data_loader:file_descriptor_data_loader",
             "//executorch/extension/evalue_util:print_evalue",
             "//executorch/extension/runner_util:inputs",
         ],
diff --git a/extension/data_loader/file_data_loader.cpp b/extension/data_loader/file_data_loader.cpp
index f5a3b94d843..1d097cfd989 100644
--- a/extension/data_loader/file_data_loader.cpp
+++ b/extension/data_loader/file_data_loader.cpp
@@ -43,8 +43,6 @@ namespace extension {
 namespace {
-static constexpr char kFdFilesystemPrefix[] = "fd:///";
-
 /**
  * Returns true if the value is an integer power of 2.
*/ @@ -76,36 +74,25 @@ FileDataLoader::~FileDataLoader() { ::close(fd_); } -static Result getFDFromUri(const char* file_descriptor_uri) { - // check if the uri starts with the prefix "fd://" +Result FileDataLoader::from( + const char* file_name, + size_t alignment) { ET_CHECK_OR_RETURN_ERROR( - strncmp( - file_descriptor_uri, - kFdFilesystemPrefix, - strlen(kFdFilesystemPrefix)) == 0, + is_power_of_2(alignment), InvalidArgument, - "File descriptor uri (%s) does not start with %s", - file_descriptor_uri, - kFdFilesystemPrefix); - - // strip "fd:///" from the uri - int fd_len = strlen(file_descriptor_uri) - strlen(kFdFilesystemPrefix); - char fd_without_prefix[fd_len + 1]; - memcpy( - fd_without_prefix, - &file_descriptor_uri[strlen(kFdFilesystemPrefix)], - fd_len); - fd_without_prefix[fd_len] = '\0'; + "Alignment %zu is not a power of 2", + alignment); - // check if remaining fd string is a valid integer - int fd = ::atoi(fd_without_prefix); - return fd; -} + // Use open() instead of fopen() to avoid the layer of buffering that + // fopen() does. We will be reading large portions of the file in one shot, + // so buffering does not help. + int fd = ::open(file_name, O_RDONLY); + if (fd < 0) { + ET_LOG( + Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno); + return Error::AccessFailed; + } -Result FileDataLoader::fromFileDescriptor( - const char* file_name, - const int fd, - size_t alignment) { // Cache the file size. 
struct stat st; int err = ::fstat(fd, &st); @@ -132,47 +119,6 @@ Result FileDataLoader::fromFileDescriptor( return FileDataLoader(fd, file_size, alignment, file_name_copy); } -Result FileDataLoader::fromFileDescriptorUri( - const char* file_descriptor_uri, - size_t alignment) { - ET_CHECK_OR_RETURN_ERROR( - is_power_of_2(alignment), - InvalidArgument, - "Alignment %zu is not a power of 2", - alignment); - - auto parsed_fd = getFDFromUri(file_descriptor_uri); - if (!parsed_fd.ok()) { - return parsed_fd.error(); - } - - int fd = parsed_fd.get(); - - return fromFileDescriptor(file_descriptor_uri, fd, alignment); -} - -Result FileDataLoader::from( - const char* file_name, - size_t alignment) { - ET_CHECK_OR_RETURN_ERROR( - is_power_of_2(alignment), - InvalidArgument, - "Alignment %zu is not a power of 2", - alignment); - - // Use open() instead of fopen() to avoid the layer of buffering that - // fopen() does. We will be reading large portions of the file in one shot, - // so buffering does not help. - int fd = ::open(file_name, O_RDONLY); - if (fd < 0) { - ET_LOG( - Error, "Failed to open %s: %s (%d)", file_name, strerror(errno), errno); - return Error::AccessFailed; - } - - return fromFileDescriptor(file_name, fd, alignment); -} - namespace { /** * FreeableBuffer::FreeFn-compatible callback. diff --git a/extension/data_loader/file_data_loader.h b/extension/data_loader/file_data_loader.h index 959684137b8..7cf2a92c4ad 100644 --- a/extension/data_loader/file_data_loader.h +++ b/extension/data_loader/file_data_loader.h @@ -26,27 +26,6 @@ namespace extension { */ class FileDataLoader final : public executorch::runtime::DataLoader { public: - /** - * Creates a new FileDataLoader that wraps the named file descriptor, and the - * ownership of the file descriptor is passed. This helper is used when ET is - * running in a process that does not have access to the filesystem, and the - * caller is able to open the file and pass the file descriptor. 
- * - * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///", - * followed by the file descriptor number. - * @param[in] alignment Alignment in bytes of pointers returned by this - * instance. Must be a power of two. - * - * @returns A new FileDataLoader on success. - * @retval Error::InvalidArgument `alignment` is not a power of two. - * @retval Error::AccessFailed `file_name` could not be opened, or its size - * could not be found. - * @retval Error::MemoryAllocationFailed Internal memory allocation failure. - */ - static executorch::runtime::Result fromFileDescriptorUri( - const char* file_descriptor_uri, - size_t alignment = alignof(std::max_align_t)); - /** * Creates a new FileDataLoader that wraps the named file. * @@ -100,11 +79,6 @@ class FileDataLoader final : public executorch::runtime::DataLoader { void* buffer) const override; private: - static executorch::runtime::Result fromFileDescriptor( - const char* file_name, - const int fd, - size_t alignment = alignof(std::max_align_t)); - FileDataLoader( int fd, size_t file_size, diff --git a/extension/data_loader/file_descriptor_data_loader.cpp b/extension/data_loader/file_descriptor_data_loader.cpp new file mode 100644 index 00000000000..48e81fd7062 --- /dev/null +++ b/extension/data_loader/file_descriptor_data_loader.cpp @@ -0,0 +1,292 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +namespace executorch { +namespace extension { + +namespace { + +static constexpr char kFdFilesystemPrefix[] = "fd:///"; + +/** + * Returns true if the value is an integer power of 2. + */ +static bool is_power_of_2(size_t value) { + return value > 0 && (value & ~(value - 1)) == value; +} + +/** + * Returns the next alignment for a given pointer. + */ +static uint8_t* align_pointer(void* ptr, size_t alignment) { + intptr_t addr = reinterpret_cast(ptr); + if ((addr & (alignment - 1)) == 0) { + // Already aligned. + return reinterpret_cast(ptr); + } + // Bump forward. + addr = (addr | (alignment - 1)) + 1; + return reinterpret_cast(addr); +} +} // namespace + +FileDescriptorDataLoader::~FileDescriptorDataLoader() { + // file_descriptor_uri_ can be nullptr if this instance was moved from, but + // freeing a null pointer is safe. + std::free(const_cast(file_descriptor_uri_)); + // fd_ can be -1 if this instance was moved from, but closing a negative fd is + // safe (though it will return an error). 
+  ::close(fd_);
+}
+
+/**
+ * Parses the file descriptor number out of a "fd:///<number>" URI.
+ *
+ * @param[in] file_descriptor_uri NUL-terminated URI to parse.
+ *
+ * @returns The file descriptor number on success.
+ * @retval Error::InvalidArgument The URI is null, does not start with
+ *     "fd:///", or the text after the prefix is not a non-negative decimal
+ *     integer that fits in an `int`.
+ */
+static Result<int> getFDFromUri(const char* file_descriptor_uri) {
+  ET_CHECK_OR_RETURN_ERROR(
+      file_descriptor_uri != nullptr,
+      InvalidArgument,
+      "File descriptor uri is null");
+
+  // Check that the uri starts with the prefix "fd:///".
+  constexpr size_t kPrefixLength = sizeof(kFdFilesystemPrefix) - 1;
+  ET_CHECK_OR_RETURN_ERROR(
+      strncmp(file_descriptor_uri, kFdFilesystemPrefix, kPrefixLength) == 0,
+      InvalidArgument,
+      "File descriptor uri (%s) does not start with %s",
+      file_descriptor_uri,
+      kFdFilesystemPrefix);
+
+  // The descriptor number is everything after the prefix. Parse it in place:
+  // the uri is already NUL-terminated, so no copy (and no variable-length
+  // array, which is not standard C++) is needed.
+  const char* fd_string = file_descriptor_uri + kPrefixLength;
+
+  // Require a leading digit so that an empty string, leading whitespace and
+  // signs are rejected; strtol() alone would accept them. Unvalidated input
+  // must not silently become fd 0, because this loader takes ownership of the
+  // descriptor and closes it on destruction.
+  ET_CHECK_OR_RETURN_ERROR(
+      fd_string[0] >= '0' && fd_string[0] <= '9',
+      InvalidArgument,
+      "File descriptor uri (%s) does not contain a valid file descriptor",
+      file_descriptor_uri);
+
+  // Reject trailing garbage ("12abc") and values that do not fit in an int.
+  errno = 0;
+  char* parse_end = nullptr;
+  const long parsed_fd = ::strtol(fd_string, &parse_end, 10);
+  ET_CHECK_OR_RETURN_ERROR(
+      errno == 0 && *parse_end == '\0' &&
+          parsed_fd <= std::numeric_limits<int>::max(),
+      InvalidArgument,
+      "File descriptor uri (%s) does not contain a valid file descriptor",
+      file_descriptor_uri);
+
+  return static_cast<int>(parsed_fd);
+}
+
+/**
+ * Builds a loader around the descriptor named by `file_descriptor_uri`.
+ *
+ * Validates `alignment`, parses the descriptor number, caches the file size
+ * reported by fstat() and keeps a private copy of the uri for log messages.
+ */
+Result<FileDescriptorDataLoader>
+FileDescriptorDataLoader::fromFileDescriptorUri(
+    const char* file_descriptor_uri,
+    size_t alignment) {
+  // NOTE(review): the two early returns below happen before the descriptor is
+  // touched, so unlike the fstat()/strdup() failure paths they do not close
+  // it. Confirm that this is the intended ownership contract for rejected
+  // calls.
+  ET_CHECK_OR_RETURN_ERROR(
+      is_power_of_2(alignment),
+      InvalidArgument,
+      "Alignment %zu is not a power of 2",
+      alignment);
+
+  auto parsed_fd = getFDFromUri(file_descriptor_uri);
+  if (!parsed_fd.ok()) {
+    return parsed_fd.error();
+  }
+
+  int fd = parsed_fd.get();
+
+  // Cache the file size.
+  struct stat st;
+  int err = ::fstat(fd, &st);
+  if (err < 0) {
+    ET_LOG(
+        Error,
+        "Could not get length of %s: %s (%d)",
+        file_descriptor_uri,
+        ::strerror(errno),
+        errno);
+    ::close(fd);
+    return Error::AccessFailed;
+  }
+  size_t file_size = st.st_size;
+
+  // Copy the filename so we can print better debug messages if reads fail.
+  const char* file_descriptor_uri_copy = ::strdup(file_descriptor_uri);
+  if (file_descriptor_uri_copy == nullptr) {
+    ET_LOG(Error, "strdup(%s) failed", file_descriptor_uri);
+    ::close(fd);
+    return Error::MemoryAllocationFailed;
+  }
+
+  return FileDescriptorDataLoader(
+      fd, file_size, alignment, file_descriptor_uri_copy);
+}
+
+namespace {
+/**
+ * FreeableBuffer::FreeFn-compatible callback.
+ *
+ * `context` is actually a ptrdiff_t value (not a pointer) that contains the
+ * offset in bytes between `data` and the actual pointer to free.
+ */
+void FreeSegment(void* context, void* data, ET_UNUSED size_t size) {
+  ptrdiff_t offset = reinterpret_cast<ptrdiff_t>(context);
+  ET_DCHECK_MSG(offset >= 0, "Unexpected offset %ld", (long int)offset);
+  std::free(static_cast<uint8_t*>(data) - offset);
+}
+} // namespace
+
+/**
+ * Reads `size` bytes starting at `offset` into a freshly malloc()ed buffer
+ * whose data pointer is aligned to `alignment_`.
+ *
+ * @retval Error::InvalidState This instance has no valid descriptor.
+ * @retval Error::InvalidArgument The requested range is not inside the file.
+ * @retval Error::MemoryAllocationFailed malloc() failed.
+ * Any error returned by load_into() is passed through.
+ */
+Result<FreeableBuffer> FileDescriptorDataLoader::load(
+    size_t offset,
+    size_t size,
+    ET_UNUSED const DataLoader::SegmentInfo& segment_info) const {
+  ET_CHECK_OR_RETURN_ERROR(
+      // Probably had its value moved to another instance.
+      fd_ >= 0,
+      InvalidState,
+      "Uninitialized");
+  // Written as two comparisons so that a huge `offset` or `size` cannot wrap
+  // `offset + size` around to a small value and slip past the bounds check.
+  ET_CHECK_OR_RETURN_ERROR(
+      offset <= file_size_ && size <= file_size_ - offset,
+      InvalidArgument,
+      "File %s: offset %zu + size %zu > file_size_ %zu",
+      file_descriptor_uri_,
+      offset,
+      size,
+      file_size_);
+
+  // Don't bother allocating/freeing for empty segments.
+  if (size == 0) {
+    return FreeableBuffer(nullptr, 0, /*free_fn=*/nullptr);
+  }
+
+  // Allocate memory for the FreeableBuffer.
+  size_t alloc_size = size;
+  if (alignment_ > alignof(std::max_align_t)) {
+    // malloc() will align to smaller values, but we must manually align to
+    // larger values.
+    alloc_size += alignment_;
+  }
+  void* buffer = std::malloc(alloc_size);
+  if (buffer == nullptr) {
+    // Report the amount that was actually requested from malloc().
+    ET_LOG(
+        Error,
+        "Reading from %s at offset %zu: malloc(%zu) failed",
+        file_descriptor_uri_,
+        offset,
+        alloc_size);
+    return Error::MemoryAllocationFailed;
+  }
+
+  // Align.
+  void* aligned_buffer = align_pointer(buffer, alignment_);
+
+  // Assert that the alignment didn't overflow the buffer.
+  ET_DCHECK_MSG(
+      reinterpret_cast<uintptr_t>(aligned_buffer) + size <=
+          reinterpret_cast<uintptr_t>(buffer) + alloc_size,
+      "aligned_buffer %p + size %zu > buffer %p + alloc_size %zu",
+      aligned_buffer,
+      size,
+      buffer,
+      alloc_size);
+
+  auto err = load_into(offset, size, segment_info, aligned_buffer);
+  if (err != Error::Ok) {
+    // Free `buffer`, which is what malloc() gave us, not `aligned_buffer`.
+    std::free(buffer);
+    return err;
+  }
+
+  // We can't naively free this pointer, since it may not be what malloc() gave
+  // us. Pass the offset to the real buffer as context. This is the number of
+  // bytes that need to be subtracted from the FreeableBuffer::data() pointer to
+  // find the actual pointer to free.
+  return FreeableBuffer(
+      aligned_buffer,
+      size,
+      FreeSegment,
+      /*free_fn_context=*/
+      reinterpret_cast<void*>(
+          // Using signed types here because it will produce a signed ptrdiff_t
+          // value, though for us it will always be non-negative.
+          reinterpret_cast<intptr_t>(aligned_buffer) -
+          reinterpret_cast<intptr_t>(buffer)));
+}
+
+/**
+ * Returns the file size that was cached when the loader was created.
+ *
+ * @retval Error::InvalidState This instance has no valid descriptor.
+ */
+Result<size_t> FileDescriptorDataLoader::size() const {
+  ET_CHECK_OR_RETURN_ERROR(
+      // Probably had its value moved to another instance.
+      fd_ >= 0,
+      InvalidState,
+      "Uninitialized");
+  return file_size_;
+}
+
+/**
+ * Reads `size` bytes starting at `offset` into the caller-provided `buffer`,
+ * looping over pread() until the whole range has been read.
+ *
+ * @retval Error::InvalidState This instance has no valid descriptor.
+ * @retval Error::InvalidArgument The requested range is not inside the file,
+ *     or `buffer` is null.
+ * @retval Error::AccessFailed pread() failed or hit EOF early.
+ */
+ET_NODISCARD Error FileDescriptorDataLoader::load_into(
+    size_t offset,
+    size_t size,
+    ET_UNUSED const SegmentInfo& segment_info,
+    void* buffer) const {
+  ET_CHECK_OR_RETURN_ERROR(
+      // Probably had its value moved to another instance.
+      fd_ >= 0,
+      InvalidState,
+      "Uninitialized");
+  // Written as two comparisons so that a huge `offset` or `size` cannot wrap
+  // `offset + size` around to a small value and slip past the bounds check.
+  ET_CHECK_OR_RETURN_ERROR(
+      offset <= file_size_ && size <= file_size_ - offset,
+      InvalidArgument,
+      "File %s: offset %zu + size %zu > file_size_ %zu",
+      file_descriptor_uri_,
+      offset,
+      size,
+      file_size_);
+  ET_CHECK_OR_RETURN_ERROR(
+      buffer != nullptr, InvalidArgument, "Provided buffer cannot be null");
+
+  // Read the data into the aligned address.
+  size_t needed = size;
+  uint8_t* buf = reinterpret_cast<uint8_t*>(buffer);
+
+  while (needed > 0) {
+    // Reads on macOS will fail with EINVAL if size > INT32_MAX.
+    const auto chunk_size = std::min<size_t>(
+        needed, static_cast<size_t>(std::numeric_limits<int32_t>::max()));
+    const auto nread = ::pread(fd_, buf, chunk_size, offset);
+    if (nread < 0 && errno == EINTR) {
+      // Interrupted by a signal; zero bytes read.
+      continue;
+    }
+    if (nread <= 0) {
+      // nread == 0 means EOF, which we shouldn't see if we were able to read
+      // the full amount. nread < 0 means an error occurred.
+      ET_LOG(
+          Error,
+          "Reading from %s: failed to read %zu bytes at offset %zu: %s",
+          file_descriptor_uri_,
+          size,
+          offset,
+          nread == 0 ? "EOF" : strerror(errno));
+      return Error::AccessFailed;
+    }
+    needed -= nread;
+    buf += nread;
+    offset += nread;
+  }
+  return Error::Ok;
+}
+
+} // namespace extension
+} // namespace executorch
diff --git a/extension/data_loader/file_descriptor_data_loader.h b/extension/data_loader/file_descriptor_data_loader.h
new file mode 100644
index 00000000000..6f51f0f7a62
--- /dev/null
+++ b/extension/data_loader/file_descriptor_data_loader.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+#include
+#include
+#include
+
+namespace executorch {
+namespace extension {
+
+/**
+ * A DataLoader that loads segments from a file descriptor, allocating the
+ * memory with `malloc()`. This data loader is used when ET is running in a
+ * process that does not have access to the filesystem, and the caller is able
+ * to open the file and pass the file descriptor.
+ *
+ * Note that this will keep the file open for the duration of its lifetime, to
+ * avoid the overhead of opening it again for every load() call.
+ */
+class FileDescriptorDataLoader final : public executorch::runtime::DataLoader {
+ public:
+  /**
+   * Creates a new FileDescriptorDataLoader that wraps the named file
+   * descriptor, and the ownership of the file descriptor is passed.
+   *
+   * @param[in] file_descriptor_uri File descriptor with the prefix "fd:///",
+   *     followed by the file descriptor number.
+   * @param[in] alignment Alignment in bytes of pointers returned by this
+   *     instance. Must be a power of two.
+   *
+   * @returns A new FileDescriptorDataLoader on success.
+   * @retval Error::InvalidArgument `alignment` is not a power of two, or
+   *     `file_descriptor_uri` does not start with the "fd:///" prefix.
+   * @retval Error::AccessFailed The size of the file behind the descriptor
+   *     could not be found.
+   * @retval Error::MemoryAllocationFailed Internal memory allocation failure.
+   */
+  static executorch::runtime::Result
+  fromFileDescriptorUri(
+      const char* file_descriptor_uri,
+      size_t alignment = alignof(std::max_align_t));
+
+  // Movable to be compatible with Result.
+  //
+  // The members are declared const, so the moved-from instance is reset
+  // through const_cast: it is left with a null uri and fd -1, which is the
+  // state the `fd_ >= 0` checks in load()/size()/load_into() report as
+  // InvalidState.
+  FileDescriptorDataLoader(FileDescriptorDataLoader&& rhs) noexcept
+      : file_descriptor_uri_(rhs.file_descriptor_uri_),
+        file_size_(rhs.file_size_),
+        alignment_(rhs.alignment_),
+        fd_(rhs.fd_) {
+    const_cast(rhs.file_descriptor_uri_) = nullptr;
+    const_cast(rhs.file_size_) = 0;
+    const_cast(rhs.alignment_) = 0;
+    const_cast(rhs.fd_) = -1;
+  }
+
+  ~FileDescriptorDataLoader() override;
+
+  ET_NODISCARD
+  executorch::runtime::Result load(
+      size_t offset,
+      size_t size,
+      const DataLoader::SegmentInfo& segment_info) const override;
+
+  ET_NODISCARD executorch::runtime::Result size() const override;
+
+  ET_NODISCARD executorch::runtime::Error load_into(
+      size_t offset,
+      size_t size,
+      ET_UNUSED const SegmentInfo& segment_info,
+      void* buffer) const override;
+
+ private:
+  // Only reachable through fromFileDescriptorUri(), which validates the
+  // arguments before constructing an instance.
+  FileDescriptorDataLoader(
+      int fd,
+      size_t file_size,
+      size_t alignment,
+      const char* file_descriptor_uri)
+      : file_descriptor_uri_(file_descriptor_uri),
+        file_size_(file_size),
+        alignment_(alignment),
+        fd_(fd) {}
+
+  // Not safely copyable.
+  FileDescriptorDataLoader(const FileDescriptorDataLoader&) = delete;
+  FileDescriptorDataLoader& operator=(const FileDescriptorDataLoader&) = delete;
+  FileDescriptorDataLoader& operator=(FileDescriptorDataLoader&&) = delete;
+
+  const char* const file_descriptor_uri_; // Owned by the instance.
+  const size_t file_size_; // Size in bytes, cached at creation time.
+  const size_t alignment_; // Alignment in bytes of the pointers handed out.
+  const int fd_; // Owned by the instance.
+};
+
+} // namespace extension
+} // namespace executorch
+
+namespace torch {
+namespace executor {
+namespace util {
+// TODO(T197294990): Remove these deprecated aliases once all users have moved
+// to the new `::executorch` namespaces.
+using ::executorch::extension::FileDescriptorDataLoader;
+} // namespace util
+} // namespace executor
+} // namespace torch
diff --git a/extension/data_loader/targets.bzl b/extension/data_loader/targets.bzl
index 4886df03a76..fcc7cba5419 100644
--- a/extension/data_loader/targets.bzl
+++ b/extension/data_loader/targets.bzl
@@ -52,6 +52,21 @@ def define_common_targets():
         ],
     )
+    runtime.cxx_library(
+        name = "file_descriptor_data_loader",
+        srcs = ["file_descriptor_data_loader.cpp"],
+        exported_headers = ["file_descriptor_data_loader.h"],
+        visibility = [
+            "//executorch/test/...",
+            "//executorch/runtime/executor/test/...",
+            "//executorch/extension/data_loader/test/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            "//executorch/runtime/core:core",
+        ],
+    )
+
     runtime.cxx_library(
         name = "mmap_data_loader",
         srcs = ["mmap_data_loader.cpp"],
diff --git a/extension/data_loader/test/file_data_loader_test.cpp b/extension/data_loader/test/file_data_loader_test.cpp
index b8921aebb54..1d4f4c16196 100644
--- a/extension/data_loader/test/file_data_loader_test.cpp
+++ b/extension/data_loader/test/file_data_loader_test.cpp
@@ -40,103 +40,6 @@ class FileDataLoaderTest : public ::testing::TestWithParam {
   }
 };
-TEST_P(FileDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) {
-  // Write some heterogeneous data to a file.
- uint8_t data[256]; - for (int i = 0; i < sizeof(data); ++i) { - data[i] = i; - } - TempFile tf(data, sizeof(data)); - - int fd = ::open(tf.path().c_str(), O_RDONLY); - - // Wrap it in a loader. - Result fdl = FileDataLoader::fromFileDescriptorUri( - ("fd:///" + std::to_string(fd)).c_str(), alignment()); - ASSERT_EQ(fdl.error(), Error::Ok); - - // size() should succeed and reflect the total size. - Result size = fdl->size(); - ASSERT_EQ(size.error(), Error::Ok); - EXPECT_EQ(*size, sizeof(data)); - - // Load the first bytes of the data. - { - Result fb = fdl->load( - /*offset=*/0, - /*size=*/8, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), 8); - EXPECT_EQ( - 0, - std::memcmp( - fb->data(), - "\x00\x01\x02\x03" - "\x04\x05\x06\x07", - fb->size())); - - // Freeing should release the buffer and clear out the segment. - fb->Free(); - EXPECT_EQ(fb->size(), 0); - EXPECT_EQ(fb->data(), nullptr); - - // Safe to call multiple times. - fb->Free(); - } - - // Load the last few bytes of the data, a different size than the first time. - { - Result fb = fdl->load( - /*offset=*/sizeof(data) - 3, - /*size=*/3, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), 3); - EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); - } - - // Loading all of the data succeeds. - { - Result fb = fdl->load( - /*offset=*/0, - /*size=*/sizeof(data), - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_ALIGNED(fb->data(), alignment()); - EXPECT_EQ(fb->size(), sizeof(data)); - EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); - } - - // Loading zero-sized data succeeds, even at the end of the data. 
- { - Result fb = fdl->load( - /*offset=*/sizeof(data), - /*size=*/0, - DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); - ASSERT_EQ(fb.error(), Error::Ok); - EXPECT_EQ(fb->size(), 0); - } -} - -TEST_P(FileDataLoaderTest, FileDescriptorLoadPrefixFail) { - // Write some heterogeneous data to a file. - uint8_t data[256]; - for (int i = 0; i < sizeof(data); ++i) { - data[i] = i; - } - TempFile tf(data, sizeof(data)); - - int fd = ::open(tf.path().c_str(), O_RDONLY); - - // Wrap it in a loader. - Result fdl = FileDataLoader::fromFileDescriptorUri( - std::to_string(fd).c_str(), alignment()); - ASSERT_EQ(fdl.error(), Error::InvalidArgument); -} - TEST_P(FileDataLoaderTest, InBoundsLoadsSucceed) { // Write some heterogeneous data to a file. uint8_t data[256]; diff --git a/extension/data_loader/test/file_descriptor_data_loader_test.cpp b/extension/data_loader/test/file_descriptor_data_loader_test.cpp new file mode 100644 index 00000000000..0258611cbd7 --- /dev/null +++ b/extension/data_loader/test/file_descriptor_data_loader_test.cpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +#include +#include +#include +#include + +using namespace ::testing; +using executorch::extension::FileDescriptorDataLoader; +using executorch::extension::testing::TempFile; +using executorch::runtime::DataLoader; +using executorch::runtime::Error; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +class FileDescriptorDataLoaderTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + } + + // The alignment in bytes that tests should use. 
The values are set by the + // list in the INSTANTIATE_TEST_SUITE_P call below. + size_t alignment() const { + return GetParam(); + } +}; + +TEST_P(FileDescriptorDataLoaderTest, InBoundsFileDescriptorLoadsSucceed) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); + + // Load the first bytes of the data. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/8, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 8); + EXPECT_EQ( + 0, + std::memcmp( + fb->data(), + "\x00\x01\x02\x03" + "\x04\x05\x06\x07", + fb->size())); + + // Freeing should release the buffer and clear out the segment. + fb->Free(); + EXPECT_EQ(fb->size(), 0); + EXPECT_EQ(fb->data(), nullptr); + + // Safe to call multiple times. + fb->Free(); + } + + // Load the last few bytes of the data, a different size than the first time. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) - 3, + /*size=*/3, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 3); + EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); + } + + // Loading all of the data succeeds. 
+ { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), sizeof(data)); + EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); + } + + // Loading zero-sized data succeeds, even at the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data), + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_EQ(fb->size(), 0); + } +} + +TEST_P(FileDescriptorDataLoaderTest, FileDescriptorLoadPrefixFail) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + std::to_string(fd).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::InvalidArgument); +} + +TEST_P(FileDescriptorDataLoaderTest, InBoundsLoadsSucceed) { + // Write some heterogeneous data to a file. + uint8_t data[256]; + for (int i = 0; i < sizeof(data); ++i) { + data[i] = i; + } + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // size() should succeed and reflect the total size. + Result size = fdl->size(); + ASSERT_EQ(size.error(), Error::Ok); + EXPECT_EQ(*size, sizeof(data)); + + // Load the first bytes of the data. 
+ { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/8, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 8); + EXPECT_EQ( + 0, + std::memcmp( + fb->data(), + "\x00\x01\x02\x03" + "\x04\x05\x06\x07", + fb->size())); + + // Freeing should release the buffer and clear out the segment. + fb->Free(); + EXPECT_EQ(fb->size(), 0); + EXPECT_EQ(fb->data(), nullptr); + + // Safe to call multiple times. + fb->Free(); + } + + // Load the last few bytes of the data, a different size than the first time. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) - 3, + /*size=*/3, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), 3); + EXPECT_EQ(0, std::memcmp(fb->data(), "\xfd\xfe\xff", fb->size())); + } + + // Loading all of the data succeeds. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data), + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_ALIGNED(fb->data(), alignment()); + EXPECT_EQ(fb->size(), sizeof(data)); + EXPECT_EQ(0, std::memcmp(fb->data(), data, fb->size())); + } + + // Loading zero-sized data succeeds, even at the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data), + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + ASSERT_EQ(fb.error(), Error::Ok); + EXPECT_EQ(fb->size(), 0); + } +} + +TEST_P(FileDescriptorDataLoaderTest, OutOfBoundsLoadFails) { + // Create a temp file; contents don't matter. + uint8_t data[256] = {}; + TempFile tf(data, sizeof(data)); + + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. 
+ Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + + // Loading beyond the end of the data should fail. + { + Result fb = fdl->load( + /*offset=*/0, + /*size=*/sizeof(data) + 1, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + EXPECT_NE(fb.error(), Error::Ok); + } + + // Loading zero bytes still fails if it's past the end of the data. + { + Result fb = fdl->load( + /*offset=*/sizeof(data) + 1, + /*size=*/0, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program)); + EXPECT_NE(fb.error(), Error::Ok); + } +} + +TEST_P(FileDescriptorDataLoaderTest, BadAlignmentFails) { + // Create a temp file; contents don't matter. + uint8_t data[256] = {}; + TempFile tf(data, sizeof(data)); + + // Creating a loader with default alignment works fine. + { + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), alignment()); + ASSERT_EQ(fdl.error(), Error::Ok); + } + + // Bad alignments fail. + const std::vector bad_alignments = {0, 3, 5, 17}; + for (size_t bad_alignment : bad_alignments) { + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. + Result fdl = + FileDescriptorDataLoader::fromFileDescriptorUri( + ("fd:///" + std::to_string(fd)).c_str(), bad_alignment); + ASSERT_EQ(fdl.error(), Error::InvalidArgument); + } +} + +// Tests that the move ctor works. +TEST_P(FileDescriptorDataLoaderTest, MoveCtor) { + // Create a loader. + std::string contents = "FILE_CONTENTS"; + TempFile tf(contents); + int fd = ::open(tf.path().c_str(), O_RDONLY); + + // Wrap it in a loader. 
+  Result fdl =
+      FileDescriptorDataLoader::fromFileDescriptorUri(
+          ("fd:///" + std::to_string(fd)).c_str(), alignment());
+  ASSERT_EQ(fdl.error(), Error::Ok);
+  EXPECT_EQ(fdl->size().get(), contents.size());
+
+  // Move it into another instance.
+  FileDescriptorDataLoader fdl2(std::move(*fdl));
+
+  // Old loader should now be invalid.
+  EXPECT_EQ(
+      fdl->load(
+             0,
+             0,
+             DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program))
+          .error(),
+      Error::InvalidState);
+  EXPECT_EQ(fdl->size().error(), Error::InvalidState);
+
+  // New loader should point to the file.
+  EXPECT_EQ(fdl2.size().get(), contents.size());
+  Result fb = fdl2.load(
+      /*offset=*/0,
+      contents.size(),
+      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
+  ASSERT_EQ(fb.error(), Error::Ok);
+  EXPECT_ALIGNED(fb->data(), alignment());
+  ASSERT_EQ(fb->size(), contents.size());
+  EXPECT_EQ(0, std::memcmp(fb->data(), contents.data(), fb->size()));
+}
+
+// Smoke test: a loader created with fromFileDescriptorUri() reports the size
+// of the underlying file.
+// NOTE(review): the test name looks carried over from the FileDataLoader
+// tests; this test only calls fromFileDescriptorUri() and does not exercise
+// any deprecated `From` entry point. Consider renaming or removing it.
+TEST_P(FileDescriptorDataLoaderTest, DEPRECATEDFrom) {
+  // Write some heterogeneous data to a file.
+  uint8_t data[256];
+  for (int i = 0; i < sizeof(data); ++i) {
+    data[i] = i;
+  }
+  TempFile tf(data, sizeof(data));
+
+  // NOTE(review): the result of open() is not checked before it is used to
+  // build the uri.
+  int fd = ::open(tf.path().c_str(), O_RDONLY);
+
+  // Wrap it in a loader.
+  Result fdl =
+      FileDescriptorDataLoader::fromFileDescriptorUri(
+          ("fd:///" + std::to_string(fd)).c_str(), alignment());
+  ASSERT_EQ(fdl.error(), Error::Ok);
+
+  // size() should succeed and reflect the total size.
+  Result size = fdl->size();
+  ASSERT_EQ(size.error(), Error::Ok);
+  EXPECT_EQ(*size, sizeof(data));
+}
+
+// Run all FileDescriptorDataLoaderTests multiple times, varying the return
+// value of `GetParam()` based on the `testing::Values` list. The tests will
+// interpret the value as "alignment".
+INSTANTIATE_TEST_SUITE_P( + VariedSegments, + FileDescriptorDataLoaderTest, + testing::Values( + 1, + 4, + alignof(std::max_align_t), + 2 * alignof(std::max_align_t), + 128, + 1024)); diff --git a/extension/data_loader/test/targets.bzl b/extension/data_loader/test/targets.bzl index 9c83d6d56b2..d424413c1bf 100644 --- a/extension/data_loader/test/targets.bzl +++ b/extension/data_loader/test/targets.bzl @@ -38,6 +38,17 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "file_descriptor_data_loader_test", + srcs = [ + "file_descriptor_data_loader_test.cpp", + ], + deps = [ + "//executorch/extension/testing_util:temp_file", + "//executorch/extension/data_loader:file_descriptor_data_loader", + ], + ) + runtime.cxx_test( name = "mmap_data_loader_test", srcs = [