Skip to content

Commit e1b65ed

Browse files
committed
Update to use IOTensor for default.
Signed-off-by: Yong Tang <[email protected]>
1 parent b8db890 commit e1b65ed

File tree

1 file changed

+126
-113
lines changed

1 file changed

+126
-113
lines changed

tensorflow_io/core/python/ops/io_tensor.py

Lines changed: 126 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,85 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import collections
21+
2022
import tensorflow as tf
2123
from tensorflow_io.core.python.ops import core_ops
2224
from tensorflow_io.kafka.python.ops.kafka_ops import kafka_ops
2325

24-
class IOTensor(object):
26+
class _IOBaseTensor(object):
27+
"""_IOBaseTensor"""
28+
29+
def __init__(self,
30+
dtype,
31+
shape,
32+
properties,
33+
internal=False):
34+
if not internal:
35+
raise ValueError("IOTensor constructor is private; please use one "
36+
"of the factory methods instead (e.g., "
37+
"IOTensor.from_tensor())")
38+
self._dtype = dtype
39+
self._shape = shape
40+
self._properties = collections.OrderedDict(
41+
{} if properties is None else properties)
42+
super(_IOBaseTensor, self).__init__()
43+
44+
#=============================================================================
45+
# Accessors
46+
#=============================================================================
47+
48+
@property
49+
def dtype(self):
50+
"""The `DType` of values in this tensor."""
51+
return self._dtype
52+
53+
@property
54+
def shape(self):
55+
"""The statically known shape of this io tensor.
56+
57+
Returns:
58+
A `TensorShape` containing the statically known shape of this io
59+
tensor. The first dimension could have a size of `None` if this
60+
io tensor is from an iterable.
61+
62+
Examples:
63+
64+
```python
65+
```
66+
"""
67+
return self._shape
68+
69+
@property
70+
def rank(self):
71+
"""The number of dimensions in this io tensor.
72+
73+
Returns:
74+
A Python `int` indicating the number of dimensions in this io
75+
tensor.
76+
"""
77+
return tf.rank(self._shape)
78+
79+
@property
80+
def properties(self):
81+
"""The properties associated with this tensor.
82+
83+
Returns:
84+
A ordered dict with name and properties associated with this tensor.
85+
"""
86+
return self._properties
87+
88+
#=============================================================================
89+
# String Encoding
90+
#=============================================================================
91+
def __repr__(self):
92+
props = "".join([
93+
", %s: %s" % (k, repr(v)) for (k, v) in self.properties.items()])
94+
return "<%s: shape=%s, dtype=%s%s>" % (
95+
type(self).__name__, self.shape, self.dtype.name, props)
96+
97+
98+
class IOTensor(_IOBaseTensor):
2599
"""IOTensor"""
26100

27101
#=============================================================================
@@ -30,6 +104,9 @@ class IOTensor(object):
30104
def __init__(self,
31105
dtype,
32106
shape,
107+
resource,
108+
function,
109+
properties,
33110
internal=False):
34111
"""Creates an `IOTensor`.
35112
@@ -42,19 +119,44 @@ def __init__(self,
42119
Args:
43120
dtype: The type of the tensor.
44121
shape: The shape of the tensor.
122+
resource: The resource associated with the IO.
123+
function: The function for indexing and accessing items with resource.
124+
properties: An ordered dict of properties to be printed.
45125
internal: True if the constructor is being called by one of the factory
46126
methods. If false, an exception will be raised.
47127
48128
Raises:
49129
TypeError:
50130
"""
51-
if not internal:
52-
raise ValueError("IOTensor constructor is private; please use one "
53-
"of the factory methods instead (e.g., "
54-
"IOTensor.from_tensor())")
55-
self._dtype = dtype
56-
self._shape = shape
57-
super(IOTensor, self).__init__()
131+
self._resource = resource
132+
self._function = function
133+
super(IOTensor, self).__init__(dtype, shape, properties, internal=internal)
134+
135+
#=============================================================================
136+
# Indexing & Slicing
137+
#=============================================================================
138+
def __getitem__(self, key):
139+
"""Returns the specified piece of this IOTensor."""
140+
if isinstance(key, slice):
141+
start = key.start
142+
stop = key.stop
143+
step = key.step
144+
if start is None:
145+
start = 0
146+
if stop is None:
147+
stop = -1
148+
if step is None:
149+
step = 1
150+
else:
151+
start = key
152+
stop = key + 1
153+
step = 1
154+
return self._function(
155+
self._resource, start, stop, step, dtype=self.dtype)
156+
157+
def __len__(self):
158+
"""Returns the total number of items of this IOTensor."""
159+
return abs(self.shape[0])
58160

59161
#=============================================================================
60162
# Factory Methods
@@ -85,7 +187,7 @@ def from_kafka(cls,
85187
subscription,
86188
servers=None,
87189
timeout=None,
88-
eof=False,
190+
eof=True,
89191
conf=None,
90192
**kwargs):
91193
"""Creates an `IOTensor` from a Kafka stream.
@@ -114,42 +216,6 @@ def from_kafka(cls,
114216
return KafkaIOTensor(
115217
subscription, servers, timeout, eof, conf, internal=True)
116218

117-
#=============================================================================
118-
# Accessors
119-
#=============================================================================
120-
121-
@property
122-
def dtype(self):
123-
"""The `DType` of values in this tensor."""
124-
return self._dtype
125-
126-
@property
127-
def shape(self):
128-
"""The statically known shape of this io tensor.
129-
130-
Returns:
131-
A `TensorShape` containing the statically known shape of this io
132-
tensor. The first dimension could have a size of `None` if this
133-
io tensor is from an iterable.
134-
135-
Examples:
136-
137-
```python
138-
```
139-
"""
140-
return self._shape
141-
142-
@property
143-
def rank(self):
144-
"""The number of dimensions in this io tensor.
145-
146-
Returns:
147-
A Python `int` indicating the number of dimensions in this io
148-
tensor.
149-
"""
150-
return tf.rank(self._shape)
151-
152-
153219
#=============================================================================
154220
# Tensor Type Conversions
155221
#=============================================================================
@@ -195,14 +261,7 @@ def to_tensor(self, **kwargs):
195261
with tf.name_scope(kwargs.get("name", "IOToTensor")):
196262
return self.__getitem__(slice(None, None))
197263

198-
#=============================================================================
199-
# String Encoding
200-
#=============================================================================
201-
def __repr__(self):
202-
return "<%s: shape=%s, dtype=%s>" % (
203-
type(self).__name__, self.shape, self.dtype.name)
204-
205-
class IOIterableTensor(IOTensor):
264+
class IOIterableTensor(_IOBaseTensor):
206265
"""IOIterableTensor"""
207266

208267
#=============================================================================
@@ -211,62 +270,21 @@ class IOIterableTensor(IOTensor):
211270
def __init__(self,
212271
dtype,
213272
shape,
214-
iter_func,
273+
function,
274+
properties,
215275
internal=False):
216276
"""Creates an `IOIterableTensor`. """
217-
self._iter_func = iter_func
218-
super(IOIterableTensor, self).__init__(dtype, shape, internal=internal)
277+
self._function = function
278+
super(IOIterableTensor, self).__init__(
279+
dtype, shape, properties, internal=internal)
219280

220281
#=============================================================================
221282
# Iterator
222283
#=============================================================================
223284
def __iter__(self):
224-
return self._iter_func()
225-
226-
class IOIndexableTensor(IOTensor):
227-
"""IOIndexableTensor"""
228-
229-
#=============================================================================
230-
# Constructor (private)
231-
#=============================================================================
232-
def __init__(self,
233-
dtype,
234-
shape,
235-
resource,
236-
getitem_func,
237-
internal=False):
238-
"""Creates an `IOIndexableTensor`. """
239-
self._resource = resource
240-
self._getitem_func = getitem_func
241-
super(IOIndexableTensor, self).__init__(dtype, shape, internal=internal)
285+
return self._function()
242286

243-
#=============================================================================
244-
# Indexing & Slicing
245-
#=============================================================================
246-
def __getitem__(self, key):
247-
"""Returns the specified piece of this IOTensor."""
248-
if isinstance(key, slice):
249-
start = key.start
250-
stop = key.stop
251-
step = key.step
252-
if start is None:
253-
start = 0
254-
if stop is None:
255-
stop = -1
256-
if step is None:
257-
step = 1
258-
else:
259-
start = key
260-
stop = key + 1
261-
step = 1
262-
return self._getitem_func(
263-
self._resource, start, stop, step, dtype=self.dtype)
264-
265-
def __len__(self):
266-
"""Returns the total number of items of this IOTensor."""
267-
return abs(self.shape[0])
268-
269-
class AudioIOTensor(IOIndexableTensor):
287+
class AudioIOTensor(IOTensor):
270288
"""AudioIOTensor"""
271289

272290
#=============================================================================
@@ -281,10 +299,12 @@ def __init__(self,
281299
dtype = tf.as_dtype(dtypes[0].numpy())
282300
shape = tf.TensorShape([
283301
None if dim < 0 else dim for dim in shapes[0].numpy() if dim != 0])
302+
properties = collections.OrderedDict({"rate": rate.numpy()})
284303
self._rate = rate.numpy()
285304
super(AudioIOTensor, self).__init__(
286305
dtype, shape,
287306
resource, core_ops.wav_indexable_get_item,
307+
properties,
288308
internal=internal)
289309

290310
#=============================================================================
@@ -296,14 +316,6 @@ def rate(self):
296316
"""The sampel `rate` of the audio stream"""
297317
return self._rate
298318

299-
#=============================================================================
300-
# String Encoding
301-
#=============================================================================
302-
def __repr__(self):
303-
return "<%s: shape=%s, dtype=%s, rate=%s>" % (
304-
type(self).__name__, self.shape, self.dtype.name, self.rate)
305-
306-
307319
class KafkaIOTensor(IOIterableTensor):
308320
"""KafkaIOTensor"""
309321

@@ -317,7 +329,7 @@ def __init__(self,
317329
dtype = tf.string
318330
shape = tf.TensorShape([None])
319331

320-
def _iter_func():
332+
def function():
321333
"""_iter_func"""
322334
metadata = []
323335
if servers is not None:
@@ -326,8 +338,9 @@ def _iter_func():
326338
metadata.append("conf.timeout=%d" % timeout)
327339
if eof is not None:
328340
metadata.append("conf.eof=%d" % (1 if eof else 0))
329-
for e in conf:
330-
metadata.append(e)
341+
if conf is not None:
342+
for e in conf:
343+
metadata.append(e)
331344
resource, _, _ = kafka_ops.kafka_iterable_init(
332345
subscription, metadata=metadata,
333346
container=scope, shared_name=subscription)
@@ -339,4 +352,4 @@ def _iter_func():
339352
yield value
340353

341354
super(KafkaIOTensor, self).__init__(
342-
dtype, shape, _iter_func, internal=internal)
355+
dtype, shape, function, properties=None, internal=internal)

0 commit comments

Comments
 (0)