@@ -41,6 +41,122 @@ class IOIndexableInterface : public IOInterface {
4141 virtual Status GetItem (const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) = 0;
4242};
4343
44+ template <typename Type>
45+ class IOIndexableImplementation : public IOIndexableInterface {
46+ public:
47+ IOIndexableImplementation<Type>(Env* env)
48+ : env_(env)
49+ , iterable_(new Type(env)) {}
50+
51+ ~IOIndexableImplementation<Type>() {}
52+ Status Init (const std::vector<string>& input, const std::vector<string>& metadata, const void * memory_data, const int64 memory_size) override {
53+
54+ TF_RETURN_IF_ERROR (iterable_->Init (input, metadata, memory_data, memory_size));
55+ TF_RETURN_IF_ERROR (iterable_->Spec (dtypes_, shapes_));
56+
57+ const int64 capacity = 4096 ;
58+ std::vector<TensorShape> chunk_shapes;
59+ for (size_t component = 0 ; component < shapes_.size (); component++) {
60+ gtl::InlinedVector<int64, 4 > dims = shapes_[component].dim_sizes ();
61+ dims[0 ] = capacity;
62+ chunk_shapes.push_back (TensorShape (dims));
63+ }
64+
65+ int64 total = 0 ;
66+
67+ int64 record_read = 0 ;
68+ do {
69+ tensors_.push_back (std::vector<Tensor>());
70+ for (size_t component = 0 ; component < shapes_.size (); component++) {
71+ tensors_.back ().push_back (Tensor (dtypes_[component], chunk_shapes[component]));
72+ }
73+ TF_RETURN_IF_ERROR (iterable_->Next (capacity, tensors_.back (), &record_read));
74+ if (record_read == 0 ) {
75+ tensors_.pop_back ();
76+ break ;
77+ }
78+ if (record_read < capacity) {
79+ for (size_t component = 0 ; component < shapes_.size (); component++) {
80+ tensors_.back ()[component] = tensors_.back ()[component].Slice (0 , record_read);
81+ }
82+ }
83+ total += record_read;
84+ } while (record_read != 0 );
85+ for (size_t component = 0 ; component < shapes_.size (); component++) {
86+ shapes_[component].set_dim (0 , total);
87+ }
88+ return Status::OK ();
89+ }
90+ virtual Status Spec (std::vector<DataType>& dtypes, std::vector<PartialTensorShape>& shapes) override {
91+ for (size_t component = 0 ; component < dtypes_.size (); component++) {
92+ dtypes.push_back (dtypes_[component]);
93+ }
94+ for (size_t component = 0 ; component < shapes_.size (); component++) {
95+ shapes.push_back (shapes_[component]);
96+ }
97+ return Status::OK ();
98+ }
99+
100+ Status Extra (std::vector<Tensor>* extra) override {
101+ return iterable_->Extra (extra);
102+ }
103+ string DebugString () const override {
104+ mutex_lock l (mu_);
105+ return strings::StrCat (" IOIndexableImplementation<" , iterable_->DebugString (), " >[]" );
106+ }
107+
108+ Status GetItem (const int64 start, const int64 stop, const int64 step, std::vector<Tensor>& tensors) override {
109+ if (step != 1 ) {
110+ return errors::InvalidArgument (" step != 1 is not supported: " , step);
111+ }
112+ // Find first chunk
113+
114+ int64 chunk_index = 0 ;
115+ int64 chunk_element = -1 ;
116+ int64 current_element = 0 ;
117+ while (chunk_index < tensors_.size ()) {
118+ if (current_element <= start && start < current_element + tensors_[chunk_index][0 ].shape ().dim_size (0 )) {
119+ chunk_element = start - current_element;
120+ current_element = start;
121+ break ;
122+ }
123+ current_element += tensors_[chunk_index][0 ].shape ().dim_size (0 );
124+ chunk_index++;
125+ }
126+ if (chunk_element < 0 ) {
127+ return errors::InvalidArgument (" start is out of range: " , start);
128+ }
129+ std::vector<Tensor> elements;
130+ for (size_t component = 0 ; component < shapes_.size (); component++) {
131+ TensorShape shape (shapes_[component].dim_sizes ());
132+ shape.RemoveDim (0 );
133+ elements.push_back (Tensor (dtypes_[component], shape));
134+ }
135+
136+ while (current_element < stop) {
137+ for (size_t component = 0 ; component < shapes_.size (); component++) {
138+ batch_util::CopySliceToElement (tensors_[chunk_index][component], &elements[component], chunk_element);
139+ batch_util::CopyElementToSlice (elements[component], &tensors[component], (current_element - start));
140+ }
141+ chunk_element++;
142+ if (chunk_element == tensors_[chunk_index][0 ].shape ().dim_size (0 )) {
143+ chunk_index++;
144+ chunk_element = 0 ;
145+ }
146+ current_element++;
147+ }
148+ return Status::OK ();
149+ }
150+ private:
151+ mutable mutex mu_;
152+ Env* env_ GUARDED_BY (mu_);
153+ std::unique_ptr<Type> iterable_ GUARDED_BY (mu_);
154+ std::vector<DataType> dtypes_ GUARDED_BY (mu_);
155+ std::vector<PartialTensorShape> shapes_ GUARDED_BY (mu_);
156+ std::vector<std::vector<Tensor>> tensors_;
157+ };
158+
159+
44160template <typename Type>
45161class IOInterfaceInitOp : public ResourceOpKernel <Type> {
46162 public:
0 commit comments