# Generalizing tf.data batching using windowing and reducers

| Status        | Accepted                                             |
| :------------ | :--------------------------------------------------- |
| **Author(s)** | Jiri Simsa (Google)                                   |
| **Sponsor**   | Derek Murray (Google)                                 |
| **Updated**   | 2018-09-19                                            |

## Objective

This proposal addresses the known limitations of the current tf.data batching API:

* it provides a mechanism for padded batching of sparse tensors
* it facilitates customization of batching logic (users can now express batching logic as a pure Python function)
* it enables application of different batching logic on different components

## **Motivation**

The tf.data API is the de facto standard for creating TensorFlow input pipelines, whose purpose is to extract data from a storage system, transform it, and load it onto an accelerator.

A common transformation performed by TensorFlow input pipelines is batching -- combining multiple tensors into a single tensor of higher dimension, most often to make a minibatch for training. Currently, the core tf.data API for batching consists of [batch](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch) and [padded_batch](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#padded_batch). The former assumes the inputs have the same shape and supports both dense and sparse inputs. The latter supports dynamically shaped inputs, such as you might find in sequential data: it assumes the inputs have the same rank but not necessarily the same shape and can pad differently shaped inputs to a common shape; only dense inputs are supported by padded_batch.

The tf.data batching API has several limitations that have surfaced in various user requests:

* As already mentioned, the padded_batch transformation does not support sparse tensor inputs ([issue](https://github.com/tensorflow/tensorflow/issues/18302)).
* The current API is not flexible enough to accept user-provided batching logic (e.g. [issue](https://github.com/tensorflow/tensorflow/issues/20391)).
* The same batching logic needs to be applied to all components of the input dataset, which is not always desirable (e.g. [issue](https://github.com/tensorflow/tensorflow/issues/20391)). Users can work around this limitation by creating separate datasets to which different batching transformations are applied and then zipping the datasets (sketched below); however, this can be inefficient, unergonomic, and error prone.

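The workaround mentioned in the last bullet might look roughly as follows; this is only a sketch, assuming a two-component dataset of `(label, sequence)` elements and a caller-defined `batch_size`:

```python
# Sketch of the current workaround (assumptions: `dataset` yields
# (label, sequence) pairs and `batch_size` is defined by the caller).
# Each component gets its own batching logic, at the cost of maintaining two
# pipelines that must stay in lockstep.
labels = dataset.map(lambda label, sequence: label).batch(batch_size)
sequences = dataset.map(lambda label, sequence: sequence).padded_batch(
    batch_size, padded_shapes=[None])
batched = tf.data.Dataset.zip((labels, sequences))
```
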
## Proposal

This document proposes leveraging the recently introduced support for _nested_ datasets as inputs to tf.data transformations to perform generalized batching as follows:

1. A <span style="text-decoration:underline;">window</span> transformation is used to combine consecutive elements of the input into a nested dataset (as opposed to a higher dimensional tensor).
2. A map transformation is used to, on a per-component basis, apply a suitable <span style="text-decoration:underline;">reducer</span> which transforms the nested dataset to a batched tensor.

The underlined transformations do not exist and are the proposed extensions to the tf.data API.

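To make the shape of the proposal concrete, here is a minimal sketch of this two-step pattern. It assumes a two-component dataset of `(features, labels)` elements, a caller-defined `batch_size`, and the `batch_dense` reducer helper defined in Example 1 below; `window` and per-component reducers are the proposed, not yet existing, APIs.

```python
# Sketch of the proposed two-step pattern (not runnable today): `window`
# produces nested datasets, and a map over nested datasets applies a reducer
# to each component, yielding batched tensors.
windows = dataset.window(batch_size, shift=batch_size)
batched = windows.map(
    lambda features, labels: (batch_dense(features), batch_dense(labels)))
```
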
### Windowing

Windowing combines elements of a dataset into finite datasets referred to as windows. This is similar to batching, with the main difference being that batching combines elements of a dataset into a higher dimensional element, while windowing combines the elements into a dataset.

```python
def window(size, shift=1, stride=1, drop_remainder=True):
  """Combines input elements into a dataset of windows.

  Each window is a dataset itself and contains `size` elements (or
  possibly fewer if there are not enough input elements to fill the window
  and `drop_remainder` evaluates to false).

  The `stride` argument determines the stride of the input elements,
  and the `shift` argument determines the shift of the window.

  For example:
  - tf.data.Dataset.range(5).window(3) produces {{0, 1, 2}, {1, 2, 3}, {2, 3, 4}}
  - tf.data.Dataset.range(5).window(3, 3, 1, False) produces {{0, 1, 2}, {3, 4}}
  - tf.data.Dataset.range(6).window(3, 1, 2) produces {{0, 2, 4}, {1, 3, 5}}

  Args:
    size: A `tf.int64` scalar `tf.Tensor`, representing the number
      of elements of the input dataset to combine into a window.
    shift: A `tf.int64` scalar `tf.Tensor`, representing the forward
      shift of the sliding window in each iteration.
    stride: A `tf.int64` scalar `tf.Tensor`, representing the stride
      of the input elements in the sliding window.
    drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether
      a window should be dropped in case its size is smaller than `size`.

  Returns:
    Dataset: A `Dataset` whose elements are a `Dataset`.
  """
```

For example:

* `tf.data.Dataset.range(5).window(3)` produces `{{0, 1, 2}, {1, 2, 3}, {2, 3, 4}}`.
* `tf.data.Dataset.range(5).window(3, 3, 1, False)` produces `{{0, 1, 2}, {3, 4}}`.
* `tf.data.Dataset.range(6).window(3, 1, 2)` produces `{{0, 2, 4}, {1, 3, 5}}`.

### Reducers

#### Example 0: Count Elements

To introduce the concept of tf.data reducers to readers unfamiliar with it, we illustrate how a reducer can be used to count the elements of a dataset:

```python
def count(dataset):
  """Counts the elements of a dataset."""

  def init_fn(_):
    return 0

  def reduce_fn(state, value):
    return state + 1

  def finalize_fn(state):
    return state

  count_reducer = tf.data.Reducer(init_fn, reduce_fn, finalize_fn)
  return dataset.reduce(count_reducer)

value = count(tf.data.Dataset.range(10))
with tf.Session() as sess:
  print(sess.run(value))  # produces 10
```

As you can see, a tf.data reducer consists of three functions: 1) an _init()_ function that sets up the initial state, which can be an arbitrary nest of tensor-like objects, 2) a _reduce()_ function that defines how to update the intermediate state given the value of the next element, and 3) a _finalize()_ function that defines how to transform the final state into the output value.

A reducer takes an entire dataset as input and reduces it to a single value. This single value is the result of taking the output of init(), calling reduce() successively on every element of the dataset until the dataset is exhausted, and then calling finalize() on the result.

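The reduction semantics described above can be sketched in plain Python as follows; this is illustrative only (the real `dataset.reduce()` runs the reducer's functions as TF functions in the runtime):

```python
# Plain-Python sketch of the reduction semantics described above.
def reduce_semantics(elements, init_fn, reduce_fn, finalize_fn):
  state = init_fn(None)              # set up the initial state
  for value in elements:
    state = reduce_fn(state, value)  # fold each element into the state
  return finalize_fn(state)          # transform the final state into the output

# reduce_semantics(range(10), lambda _: 0, lambda s, v: s + 1, lambda s: s)
# returns 10, mirroring the count() example above.
```
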
#### Example 1: Batch of Dense Tensors

Next, we illustrate how tf.data reducers can be used to create a batch from a dataset of dense tensors.

```python
from tensorflow.python.ops import gen_array_ops  # internal module providing the `empty` op

def batch_dense(dataset):
  """Batches a dataset of dense tensors."""

  if dataset.output_shapes.is_fully_defined():
    shape = dataset.output_shapes
  else:
    first_element = tf.contrib.data.get_single_element(dataset.take(1))
    shape = tf.shape(first_element)

  def batch_init_fn(_):
    """Return an empty Tensor of the correct shape and type."""
    batch_shape = tf.concat([[0], shape], 0)
    return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)

  def batch_reduce_fn(state, value):
    """Append this value to what we have of the batch so far."""
    return tf.concat([state, [value]], 0)

  def batch_finalize_fn(state):
    """Return the batch tensor as constructed so far."""
    return state

  batch_reducer = tf.data.Reducer(batch_init_fn, batch_reduce_fn,
                                  batch_finalize_fn)
  return dataset.reduce(batch_reducer)

batch = batch_dense(tf.data.Dataset.range(5))
with tf.Session() as sess:
  print(sess.run(batch))  # produces [0 1 2 3 4]
```

#### Example 2: Padded Batch of Dense Tensors

Our next tf.data reducer example illustrates how to use a reducer to create a padded batch from a dataset of dense tensors.

```python
from tensorflow.python.data.util import convert  # internal module for shape conversion

def padded_batch_dense(dataset, padded_shape, padding_value):
  """Batches a dataset of dense tensors with padding."""

  padded_shape = tf.cast(
      convert.partial_shape_to_tensor(padded_shape), tf.int32)

  def init_fn(_):
    return 0, padded_shape

  def reduce_fn(state, value):
    count, shape = state
    return count + 1, tf.maximum(shape, tf.shape(value))

  def finalize_fn(state):
    return state

  # Compute the padded shape and count elements.
  reducer = tf.data.Reducer(init_fn, reduce_fn, finalize_fn)
  count, padded_shape = dataset.reduce(reducer)

  def pad_fn(value):
    shape = tf.shape(value)
    left = tf.zeros_like(shape)
    right = padded_shape - shape
    return tf.pad(value, tf.stack([left, right], 1),
                  constant_values=padding_value)

  return dataset.map(pad_fn).batch(count)

padded_batch = padded_batch_dense(
    tf.data.Dataset.from_tensor_slices([[1], [2]]), [2],
    0).make_one_shot_iterator().get_next()
with tf.Session() as sess:
  print(sess.run(padded_batch))  # produces [[1 0] [2 0]]
```

### End-to-end Example

Finally, we illustrate how to use the window transformation to perform generalized tf.data batching:

```python
import tensorflow as tf

def gen():
  yield ('a', [1])
  yield ('b', [2])
  yield ('c', [3])
  yield ('d', [4, 4])

def map_fn(a, b):
  return tf.data.Dataset.zip((a.batch(2), b.padded_batch(2, [2])))

dataset = tf.data.Dataset.from_generator(gen, (tf.string, tf.int32))
dataset = dataset.window(2, 2).flat_map(map_fn)
get_next = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(get_next))  # produces (['a', 'b'], [[1, 0], [2, 0]])
  print(sess.run(get_next))  # produces (['c', 'd'], [[3, 0], [4, 4]])
```

## API Changes

This design document proposes the following changes to the tf.data API:

* Adding a `tf.data.Dataset.window` method, which provides the windowing functionality described in this proposal.
* Promoting the `tf.contrib.data.reduce_dataset()` method to `tf.data.Dataset.reduce()` and the `tf.contrib.data.Reducer` class to `tf.data.Reducer`.
* Allowing nested datasets as inputs of `map` and `filter`.
* Adding canned reducers for padded batching of dense and sparse tensors to `tf.contrib.data`, changing the implementation of `tf.data.Dataset.padded_batch()` to use these, and marking it as deprecated.

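For illustration, the promotion in the second bullet would roughly amount to the following; the exact `tf.contrib.data.reduce_dataset()` argument order shown here is an assumption, and `reducer` is a `Reducer` as in the examples above:

```python
# Hedged sketch of the promotion described above (contrib signature assumed).
# Today (contrib):
value = tf.contrib.data.reduce_dataset(dataset, reducer)
# Proposed (core):
value = dataset.reduce(reducer)
```
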
## Summary

This proposal addresses known limitations of the current tf.data batching API:

* it provides a mechanism for padded batching of sparse tensors
* it facilitates customization of batching logic (users can now express batching logic as a pure Python function)
* it enables application of different batching logic on different components

## Discussion Notes

See also notes from the [public review](https://github.com/tensorflow/community/pull/5). The following notes were taken in the review committee meeting.

Q: What is the value added by the new examples?

A: The previous examples were inefficient versions of things that already exist.

Q: The obvious use of the API led to an inefficient implementation (of batching, using tf.concat()). It might be hard to write batching in this API without it being inefficient.

A: This API is not meant to be used to implement something that already exists.

Q: Is this not a good API for implementing batching? The structure encourages inefficient implementations.

A: The point was not to illustrate how we do batching efficiently. It's already done.

Q: I thought the point was to show many different ways to do batching.

A: The base case is still an efficient implementation of batch, but we can add other logic around it (e.g. to do different forms of padding, etc.).

Q: What were the biggest questions?

A: Batching efficiency was the biggest one. Some questions about the signature of the newly introduced transformation. One reader commented that the meaning of "window" in other communities (video processing) typically includes some notion of slide/stride. Conclusion was that we will support shift and stride as we already do in `sliding_window_batch()`. Stride = number of elements you skip (i.e. for non-consecutive elements in a window), shift = how much the window shifts between windows.

Q: Is there any significant overhead from elements being datasets (e.g. from extra work in Python)?

A: The amount of computation that you have to do to compute the batch should be the same. There is no additional work in Python.

Q: How do you compile the reduce function to run it in C++?

A: It's a TF function, similar to existing map functions, etc.

Q: Concern about how many times count() is invoked.

A: The example shows how to use it in a filter(), where the count is evaluated in a function context.

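To make that answer concrete, a minimal sketch of the pattern (assuming the proposed `window` transformation, filter over nested datasets, and the `count` helper from Example 0 above; `filter_small_windows` is a hypothetical helper introduced here for illustration):

```python
# Hypothetical sketch: keep only windows with at least `min_size` elements.
# The `count` reducer is evaluated inside the filter predicate, i.e. in a
# function context, once per window.
def filter_small_windows(dataset, size, min_size):
  windows = dataset.window(size, shift=size, drop_remainder=False)
  return windows.filter(
      lambda window: tf.greater_equal(count(window), min_size))
```
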
Q: Regarding runtime efficiency, in the higher dimensional case, would we always make a copy to concatenate?

A: That's what the Dataset.batch() transformation does. The nested dataset elements aren't intended for direct consumption, but to serve as input to other transformations, which e.g. build padded batches, sparse tensors, etc. This proposal lets you mix and match how you treat the different components, as illustrated in the end-to-end example. The goal of the new API isn't to improve efficiency of the existing implementations, but to add support for new kinds of transformations.

Q: What about the parallel proposal for random access datasets? Will count() be an exposed primitive or would you use the efficient random-access count?

A: We would add efficient random-access count for the nested datasets produced by window().