From cddc54f61aaff9d731f56405ff29edd1f69157a5 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 12 Jun 2019 17:16:44 +0000 Subject: [PATCH 1/3] Add batch support for VideoDataset Signed-off-by: Yong Tang --- tensorflow_io/video/BUILD | 27 +-- .../video/kernels/video_dataset_ops.cc | 187 ------------------ tensorflow_io/video/kernels/video_input.cc | 88 +++++++++ .../ops/{dataset_ops.cc => video_ops.cc} | 23 ++- tensorflow_io/video/python/ops/video_ops.py | 17 +- tests/test_video.py | 2 +- 6 files changed, 126 insertions(+), 218 deletions(-) delete mode 100644 tensorflow_io/video/kernels/video_dataset_ops.cc create mode 100644 tensorflow_io/video/kernels/video_input.cc rename tensorflow_io/video/ops/{dataset_ops.cc => video_ops.cc} (62%) diff --git a/tensorflow_io/video/BUILD b/tensorflow_io/video/BUILD index 049674cf3..1a07f6ce9 100644 --- a/tensorflow_io/video/BUILD +++ b/tensorflow_io/video/BUILD @@ -6,23 +6,20 @@ cc_binary( name = "python/ops/_video_ops_ffmpeg_3.4.so", srcs = [ "kernels/ffmpeg.cc", - "kernels/video_dataset_ops.cc", + "kernels/video_input.cc", "kernels/video_reader.h", - "ops/dataset_ops.cc", + "ops/video_ops.cc", ], copts = [ "-pthread", "-std=c++11", "-DNDEBUG", ], - data = [], includes = ["."], - linkopts = [], linkshared = 1, deps = [ "@ffmpeg_3_4//:ffmpeg", - "@local_config_tf//:libtensorflow_framework", - "@local_config_tf//:tf_header_lib", + "//tensorflow_io/core:dataset_ops", ], ) @@ -30,23 +27,20 @@ cc_binary( name = "python/ops/_video_ops_ffmpeg_2.8.so", srcs = [ "kernels/ffmpeg.cc", - "kernels/video_dataset_ops.cc", + "kernels/video_input.cc", "kernels/video_reader.h", - "ops/dataset_ops.cc", + "ops/video_ops.cc", ], copts = [ "-pthread", "-std=c++11", "-DNDEBUG", ], - data = [], includes = ["."], - linkopts = [], linkshared = 1, deps = [ "@ffmpeg_2_8//:ffmpeg", - "@local_config_tf//:libtensorflow_framework", - "@local_config_tf//:tf_header_lib", + "//tensorflow_io/core:dataset_ops", ], ) @@ -54,22 +48,19 @@ cc_binary( name = 
"python/ops/_video_ops_libav_9.20.so", srcs = [ "kernels/ffmpeg.cc", - "kernels/video_dataset_ops.cc", + "kernels/video_input.cc", "kernels/video_reader.h", - "ops/dataset_ops.cc", + "ops/video_ops.cc", ], copts = [ "-pthread", "-std=c++11", "-DNDEBUG", ], - data = [], includes = ["."], - linkopts = [], linkshared = 1, deps = [ "@libav_9_20//:libav", - "@local_config_tf//:libtensorflow_framework", - "@local_config_tf//:tf_header_lib", + "//tensorflow_io/core:dataset_ops", ], ) diff --git a/tensorflow_io/video/kernels/video_dataset_ops.cc b/tensorflow_io/video/kernels/video_dataset_ops.cc deleted file mode 100644 index 9d6852aa8..000000000 --- a/tensorflow_io/video/kernels/video_dataset_ops.cc +++ /dev/null @@ -1,187 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" -#include "tensorflow/core/platform/file_system.h" -#include "kernels/video_reader.h" - -namespace tensorflow { -namespace data { -namespace { - -static mutex mu(LINKER_INITIALIZED); -static unsigned count(0); -void VideoReaderInit() { - mutex_lock lock(mu); - count++; - if (count == 1) { - // Register all formats and codecs - av_register_all(); - } -} - -class VideoDatasetOp : public DatasetOpKernel { - public: - using DatasetOpKernel::DatasetOpKernel; - explicit VideoDatasetOp(OpKernelConstruction* ctx) - : DatasetOpKernel(ctx) { - } - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - const Tensor* filenames_tensor; - OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); - OP_REQUIRES( - ctx, filenames_tensor->dims() <= 1, - errors::InvalidArgument("`filenames` must be a scalar or a vector.")); - - std::vector filenames; - filenames.reserve(filenames_tensor->NumElements()); - for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat()(i)); - } - - *output = new Dataset(ctx, filenames); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const std::vector& filenames) - : DatasetBase(DatasetContext(ctx)), - filenames_(filenames) {} - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr( - new Iterator({this, strings::StrCat(prefix, "::Video")})); - } - - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_UINT8}); - return *dtypes; - } - - const std::vector& output_shapes() const override { - static std::vector* shapes = - new std::vector({{-1, -1, 3}}); - return *shapes; - } - - string DebugString() const override { - return 
"VideoDatasetOp::Dataset"; - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* filenames = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - do { - // We are currently processing a file, so try to read the next record. - if (reader_) { - int num_bytes, height, width; - uint8_t *value; - Status status = reader_->ReadFrame(&num_bytes, &value, &height, &width); - if (!errors::IsOutOfRange(status)) { - TF_RETURN_IF_ERROR(status); - - Tensor value_tensor(ctx->allocator({}), DT_UINT8, {height, width, 3}); - std::memcpy(reinterpret_cast(value_tensor.flat().data()), reinterpret_cast(value), num_bytes * sizeof(uint8_t)); - out_tensors->emplace_back(std::move(value_tensor)); - - *end_of_sequence = false; - return Status::OK(); - } - // We have reached the end of the current file, so maybe - // move on to next file. - ResetStreamsLocked(); - ++current_file_index_; - } - - // Iteration ends when there are no more files to process. 
- if (current_file_index_ == dataset()->filenames_.size()) { - *end_of_sequence = true; - return Status::OK(); - } - - TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); - } while (true); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - return errors::Unimplemented("SaveInternal is currently not supported"); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - return errors::Unimplemented( - "RestoreInternal is currently not supported"); - } - - private: - // Sets up Video streams to read from the topic at - // `current_file_index_`. - Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - VideoReaderInit(); - - if (current_file_index_ >= dataset()->filenames_.size()) { - return errors::InvalidArgument( - "current_file_index_:", current_file_index_, - " >= filenames_.size():", dataset()->filenames_.size()); - } - - // Actually move on to next file. - const string& filename = dataset()->filenames_[current_file_index_]; - reader_.reset(new video::VideoReader(filename)); - return reader_->ReadHeader(); - return Status::OK(); - } - - // Resets all Video streams. - void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - reader_.reset(); - } - - mutex mu_; - size_t current_file_index_ GUARDED_BY(mu_) = 0; - std::unique_ptr reader_ GUARDED_BY(mu_); - }; - - const std::vector filenames_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("VideoDataset").Device(DEVICE_CPU), - VideoDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow_io/video/kernels/video_input.cc b/tensorflow_io/video/kernels/video_input.cc new file mode 100644 index 000000000..09dd110f1 --- /dev/null +++ b/tensorflow_io/video/kernels/video_input.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "kernels/dataset_ops.h" +#include "kernels/video_reader.h" + +namespace tensorflow { +namespace data { + +static mutex mu(LINKER_INITIALIZED); +static unsigned count(0); +void VideoReaderInit() { + mutex_lock lock(mu); + count++; + if (count == 1) { + // Register all formats and codecs + av_register_all(); + } +} + +class VideoInput: public FileInput { + public: + Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { + if (state.get() == nullptr) { + VideoReaderInit(); + state.reset(new video::VideoReader(filename())); + TF_RETURN_IF_ERROR(state.get()->ReadHeader()); + } + // Read the first frame to get height and width + int num_bytes, height, width; + uint8_t *value; + Status status = state.get()->ReadFrame(&num_bytes, &value, &height, &width); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return status; + } + if (!status.ok()) { + return Status::OK(); + } + Tensor value_tensor(ctx->allocator({}), DT_UINT8, {record_to_read, height, width, 3}); + std::memcpy(reinterpret_cast(value_tensor.flat().data()), reinterpret_cast(value), num_bytes * sizeof(uint8_t)); + (*record_read)++; + while ((*record_read) < record_to_read) { + Status status = state.get()->ReadFrame(&num_bytes, &value, &height, &width); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return status; + } + if (!status.ok()) { + break; + } + int64 offset = (*record_read) * 
height * width * 3; + std::memcpy(reinterpret_cast(&value_tensor.flat().data()[offset]), reinterpret_cast(value), num_bytes * sizeof(uint8_t)); + (*record_read)++; + } + out_tensors->emplace_back(std::move(value_tensor)); + return Status::OK(); + } + Status FromStream(io::InputStreamInterface* s) override { + return Status::OK(); + } + void EncodeAttributes(VariantTensorData* data) const override { + } + bool DecodeAttributes(const VariantTensorData& data) override { + return true; + } + protected: +}; + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VideoInput, "tensorflow::data::VideoInput"); + +REGISTER_KERNEL_BUILDER(Name("VideoInput").Device(DEVICE_CPU), + FileInputOp); +REGISTER_KERNEL_BUILDER(Name("VideoDataset").Device(DEVICE_CPU), + FileInputDatasetOp); + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow_io/video/ops/dataset_ops.cc b/tensorflow_io/video/ops/video_ops.cc similarity index 62% rename from tensorflow_io/video/ops/dataset_ops.cc rename to tensorflow_io/video/ops/video_ops.cc index f819b6a4a..f9f7b3070 100644 --- a/tensorflow_io/video/ops/dataset_ops.cc +++ b/tensorflow_io/video/ops/video_ops.cc @@ -19,13 +19,28 @@ limitations under the License. 
namespace tensorflow { +REGISTER_OP("VideoInput") + .Input("source: string") + .Output("handle: variant") + .Attr("filters: list(string) = []") + .Attr("columns: list(string) = []") + .Attr("schema: string = ''") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + return Status::OK(); + }); + REGISTER_OP("VideoDataset") - .Input("filenames: string") + .Input("input: T") + .Input("batch: int64") .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("T: {string, variant} = DT_VARIANT") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim(), 3})); - return Status::OK(); - }); + c->set_output(0, c->MakeShape({})); + return Status::OK(); + }); } // namespace tensorflow diff --git a/tensorflow_io/video/python/ops/video_ops.py b/tensorflow_io/video/python/ops/video_ops.py index c34a48059..9d8bcbe8e 100644 --- a/tensorflow_io/video/python/ops/video_ops.py +++ b/tensorflow_io/video/python/ops/video_ops.py @@ -68,10 +68,10 @@ def load_dependency_and_library(p): ], }) -class VideoDataset(data_ops.BaseDataset): +class VideoDataset(data_ops.Dataset): """A Video File Dataset that reads the video file.""" - def __init__(self, filename): + def __init__(self, filename, batch=None): """Create a `VideoDataset`. `VideoDataset` allows a user to read data from a video file with @@ -94,10 +94,11 @@ def __init__(self, filename): Args: filename: A `tf.string` tensor containing one or more filenames. 
""" - batch = None # TODO: Add batch support - self._batch = 0 if batch is None else batch - self._dtypes = [tf.uint8] - self._shapes = [tf.TensorShape([None, None, 3])] + batch = 0 if batch is None else batch + dtypes = [tf.uint8] + shapes = [ + tf.TensorShape([None, None, 3])] if batch == 0 else [ + tf.TensorShape([None, None, None, 3])] super(VideoDataset, self).__init__( - video_ops.video_dataset(filename), - self._batch, self._dtypes, self._shapes) + video_ops.video_dataset, + video_ops.video_input(filename), batch, dtypes, shapes) diff --git a/tests/test_video.py b/tests/test_video.py index 201409d0f..f39a22406 100644 --- a/tests/test_video.py +++ b/tests/test_video.py @@ -31,7 +31,7 @@ def test_video_predict(): model = tf.keras.applications.resnet50.ResNet50(weights='imagenet') - x = video_io.VideoDataset(video_path).batch(1).map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224)))) + x = video_io.VideoDataset(video_path, batch=1).map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224)))) y = model.predict(x) p = tf.keras.applications.resnet50.decode_predictions(y, top=1) assert len(p) == 166 From fcb998a3edca7ffa9e59d6a1e3377680ff975851 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 12 Jun 2019 22:49:44 +0000 Subject: [PATCH 2/3] Support read video from stream (instead of file name). 
Fixes 144 Signed-off-by: Yong Tang --- tensorflow_io/video/BUILD | 6 +-- tensorflow_io/video/kernels/ffmpeg.cc | 59 ++++++++++++++++++++++ tensorflow_io/video/kernels/video_input.cc | 3 +- tensorflow_io/video/kernels/video_reader.h | 7 ++- tests/test_video.py | 2 +- tests/test_video_eager.py | 2 +- 6 files changed, 71 insertions(+), 8 deletions(-) diff --git a/tensorflow_io/video/BUILD b/tensorflow_io/video/BUILD index 1a07f6ce9..10038996e 100644 --- a/tensorflow_io/video/BUILD +++ b/tensorflow_io/video/BUILD @@ -18,8 +18,8 @@ cc_binary( includes = ["."], linkshared = 1, deps = [ + "//tensorflow_io/core:dataset_ops", "@ffmpeg_3_4//:ffmpeg", - "//tensorflow_io/core:dataset_ops", ], ) @@ -39,8 +39,8 @@ cc_binary( includes = ["."], linkshared = 1, deps = [ + "//tensorflow_io/core:dataset_ops", "@ffmpeg_2_8//:ffmpeg", - "//tensorflow_io/core:dataset_ops", ], ) @@ -60,7 +60,7 @@ cc_binary( includes = ["."], linkshared = 1, deps = [ + "//tensorflow_io/core:dataset_ops", "@libav_9_20//:libav", - "//tensorflow_io/core:dataset_ops", ], ) diff --git a/tensorflow_io/video/kernels/ffmpeg.cc b/tensorflow_io/video/kernels/ffmpeg.cc index aa7ed0656..e638d2b05 100644 --- a/tensorflow_io/video/kernels/ffmpeg.cc +++ b/tensorflow_io/video/kernels/ffmpeg.cc @@ -34,8 +34,63 @@ namespace tensorflow { namespace data { namespace video { +static int io_read_packet(void *opaque, uint8_t *buf, int buf_size) { + VideoReader *r = (VideoReader *)opaque; + StringPiece result; + Status status = r->stream_->Read(r->offset_, buf_size, &result, (char *)buf); + if (!(status.ok() || errors::IsOutOfRange(status))) { + return -1; + } + r->offset_ += result.size(); + return result.size(); +} + +static int64_t io_seek(void *opaque, int64_t offset, int whence) { + VideoReader *r = (VideoReader *)opaque; + uint64 file_size = 0; + Status status = r->stream_->GetFileSize(&file_size); + if (!status.ok()) { + return -1; + } + switch (whence) + { + case SEEK_SET: + if (offset > file_size) { + return -1; + } + 
r->offset_ = offset; + return r->offset_; + case SEEK_CUR: + if (r->offset_ + offset > file_size) { + return -1; + } + r->offset_ += offset; + return r->offset_; + case SEEK_END: // fseek semantics: offset is relative to EOF, valid range is [-file_size, 0] + if (offset > 0 || offset < -static_cast<int64_t>(file_size)) { + return -1; + } + r->offset_ = static_cast<int64_t>(file_size) + offset; + return r->offset_; + case AVSEEK_SIZE: + return file_size; + default: + break; + } + return -1; +} + Status VideoReader::ReadHeader() { + // Allocate format + if ((format_context_ = avformat_alloc_context()) == NULL) { + return errors::InvalidArgument("could not allocate format context"); + } + // Allocate context + if ((io_context_ = avio_alloc_context(NULL, 0, 0, this, io_read_packet, NULL, io_seek)) == NULL) { + return errors::InvalidArgument("could not allocate io context"); + } + format_context_->pb = io_context_; // Open input file, and allocate format context if (avformat_open_input(&format_context_, filename_.c_str(), NULL, NULL) < 0) { return errors::InvalidArgument("could not open video file: ", filename_); @@ -206,6 +261,10 @@ VideoReader::~VideoReader() { avcodec_free_context(&codec_context_); #endif avformat_close_input(&format_context_); + av_free(format_context_); + if (io_context_ != NULL) { + av_free(io_context_); + } } } // namespace diff --git a/tensorflow_io/video/kernels/video_input.cc b/tensorflow_io/video/kernels/video_input.cc index 09dd110f1..bc2ccf35b 100644 --- a/tensorflow_io/video/kernels/video_input.cc +++ b/tensorflow_io/video/kernels/video_input.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "kernels/dataset_ops.h" #include "kernels/video_reader.h" namespace tensorflow { @@ -35,7 +34,7 @@ class VideoInput: public FileInput { Status ReadRecord(io::InputStreamInterface* s, IteratorContext* ctx, std::unique_ptr& state, int64 record_to_read, int64* record_read, std::vector* out_tensors) const override { if (state.get() == nullptr) { VideoReaderInit(); - state.reset(new video::VideoReader(filename())); + state.reset(new video::VideoReader(dynamic_cast<SizedRandomAccessInputStreamInterface*>(s), filename())); TF_RETURN_IF_ERROR(state.get()->ReadHeader()); } // Read the first frame to get height and width diff --git a/tensorflow_io/video/kernels/video_reader.h b/tensorflow_io/video/kernels/video_reader.h index b8059b7c0..15afe5691 100644 --- a/tensorflow_io/video/kernels/video_reader.h +++ b/tensorflow_io/video/kernels/video_reader.h @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/platform/file_system.h" +#include "kernels/dataset_ops.h" extern "C" { @@ -33,7 +34,7 @@ namespace video { class VideoReader { public: - explicit VideoReader(const string &filename) : filename_(filename) {} + explicit VideoReader(SizedRandomAccessInputStreamInterface* s, const string& filename) : stream_(s), filename_(filename) {} Status ReadHeader(); @@ -43,6 +44,9 @@ class VideoReader { virtual ~VideoReader(); + public: + SizedRandomAccessInputStreamInterface* stream_; + int64 offset_ = 0; private: std::string ahead_; std::string filename_; @@ -58,6 +62,7 @@ class VideoReader { AVCodecContext *codec_context_ = 0; AVFrame *frame_ = 0; AVPacket packet_; + AVIOContext *io_context_ = NULL; TF_DISALLOW_COPY_AND_ASSIGN(VideoReader); }; diff --git a/tests/test_video.py b/tests/test_video.py index f39a22406..a179284a3 100644 --- a/tests/test_video.py +++ b/tests/test_video.py @@ -28,7 +28,7 @@ video_path = 
os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_video", "small.mp4") - +video_path = "file://" + video_path def test_video_predict(): model = tf.keras.applications.resnet50.ResNet50(weights='imagenet') x = video_io.VideoDataset(video_path, batch=1).map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224)))) diff --git a/tests/test_video_eager.py b/tests/test_video_eager.py index af95cbed9..b1b7ebb4d 100644 --- a/tests/test_video_eager.py +++ b/tests/test_video_eager.py @@ -30,7 +30,7 @@ video_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_video", "small.mp4") - +video_path = "file://" + video_path def test_video_dataset(): """test_video_dataset""" num_repeats = 2 From 11e9c0c79b95c36a80dcd0e3418fb635ea8cbfa3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 14 Jun 2019 14:36:00 +0000 Subject: [PATCH 3/3] Address review feedback Signed-off-by: Yong Tang --- tensorflow_io/video/python/ops/video_ops.py | 3 +++ tests/test_video.py | 3 +-- tests/test_video_eager.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow_io/video/python/ops/video_ops.py b/tensorflow_io/video/python/ops/video_ops.py index 9d8bcbe8e..cd4c066f1 100644 --- a/tensorflow_io/video/python/ops/video_ops.py +++ b/tensorflow_io/video/python/ops/video_ops.py @@ -93,6 +93,9 @@ def __init__(self, filename, batch=None): Args: filename: A `tf.string` tensor containing one or more filenames. + batch: An integer representing the number of consecutive image frames + to combine in a single batch. If `batch` is `None` (default) or 0, + each element of the dataset is one standalone image frame. 
""" batch = 0 if batch is None else batch dtypes = [tf.uint8] diff --git a/tests/test_video.py b/tests/test_video.py index a179284a3..3686d0a92 100644 --- a/tests/test_video.py +++ b/tests/test_video.py @@ -26,9 +26,8 @@ pytest.skip("video is not supported on macOS yet", allow_module_level=True) import tensorflow_io.video as video_io # pylint: disable=wrong-import-position -video_path = os.path.join( +video_path = "file://" + os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_video", "small.mp4") -video_path = "file://" + video_path def test_video_predict(): model = tf.keras.applications.resnet50.ResNet50(weights='imagenet') x = video_io.VideoDataset(video_path, batch=1).map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224)))) diff --git a/tests/test_video_eager.py b/tests/test_video_eager.py index b1b7ebb4d..766765304 100644 --- a/tests/test_video_eager.py +++ b/tests/test_video_eager.py @@ -28,9 +28,8 @@ pytest.skip("video is not supported on macOS yet", allow_module_level=True) import tensorflow_io.video as video_io # pylint: disable=wrong-import-position -video_path = os.path.join( +video_path = "file://" + os.path.join( os.path.dirname(os.path.abspath(__file__)), "test_video", "small.mp4") -video_path = "file://" + video_path def test_video_dataset(): """test_video_dataset""" num_repeats = 2