|
17 | 17 | from __future__ import division |
18 | 18 | from __future__ import print_function |
19 | 19 |
|
| 20 | +import sys |
20 | 21 | import collections |
21 | 22 |
|
22 | 23 | import tensorflow as tf |
@@ -325,7 +326,55 @@ def __init__(self, |
325 | 326 | # Iterator |
326 | 327 | #============================================================================= |
327 | 328 | def __iter__(self): |
328 | | - return self._function() |
| 329 | + resource = self._function["init"](self._function["data"]) |
| 330 | + capacity = 1 |
| 331 | + while True: |
| 332 | + value = self._function["next"](resource, capacity=capacity) |
| 333 | + if tf.shape(value)[0].numpy() < capacity: |
| 334 | + return |
| 335 | + yield value |
| 336 | + |
| 337 | + #============================================================================= |
| 338 | + # Dataset Conversions |
| 339 | + #============================================================================= |
| 340 | + |
| 341 | + def to_dataset(self): |
| 342 | + """Converts this `IOIterableTensor` into a `tf.data.Dataset`. |
| 343 | +
|
| 344 | + Example: |
| 345 | +
|
| 346 | + ```python |
| 347 | + ``` |
| 348 | +
|
| 349 | + Args: |
| 350 | +
|
| 351 | + Returns: |
| 352 | + A `tf.data.Dataset` with value obtained from this `IOIterableTensor`. |
| 353 | + """ |
| 354 | + class _IOIterableTensorDataset(data_ops.BaseDataset): |
| 355 | + """_IOIterableTensorDataset""" |
| 356 | + |
| 357 | + def __init__(self, dtype, shape, function): |
| 358 | + func_init = function["init"] |
| 359 | + func_next = function["next"] |
| 360 | + func_data = function["data"] |
| 361 | + resource = func_init(func_data) |
| 362 | + capacity = 4096 |
| 363 | + dataset = data_ops.BaseDataset.range( |
| 364 | + 0, sys.maxsize, capacity).map( |
| 365 | + lambda i: func_next(resource, capacity)).apply( |
| 366 | + tf.data.experimental.take_while( |
| 367 | + lambda v: tf.greater(tf.shape(v)[0], 0))).apply( |
| 368 | + tf.data.experimental.unbatch()) |
| 369 | + self._dataset = dataset |
| 370 | + self._resource = resource |
| 371 | + self._function = function |
| 372 | + shape = shape[1:] |
| 373 | + super(_IOIterableTensorDataset, self).__init__( |
| 374 | + self._dataset._variant_tensor, [dtype], [shape]) # pylint: disable=protected-access |
| 375 | + |
| 376 | + return _IOIterableTensorDataset( |
| 377 | + self._dtype, self._shape, self._function) |
329 | 378 |
|
330 | 379 | class AudioIOTensor(IOTensor): |
331 | 380 | """AudioIOTensor""" |
@@ -369,30 +418,31 @@ def __init__(self, |
369 | 418 | subscription, servers, timeout, eof, conf, |
370 | 419 | internal=False): |
371 | 420 | with tf.name_scope("KafkaIOTensor") as scope: |
372 | | - dtype = tf.string |
373 | | - shape = tf.TensorShape([None]) |
374 | 421 |
|
375 | | - def function(): |
376 | | - """_iter_func""" |
377 | | - metadata = [] |
378 | | - if servers is not None: |
379 | | - metadata.append("bootstrap.servers=%s" % servers) |
380 | | - if timeout is not None: |
381 | | - metadata.append("conf.timeout=%d" % timeout) |
382 | | - if eof is not None: |
383 | | - metadata.append("conf.eof=%d" % (1 if eof else 0)) |
384 | | - if conf is not None: |
385 | | - for e in conf: |
386 | | - metadata.append(e) |
| 422 | + metadata = [] |
| 423 | + if servers is not None: |
| 424 | + metadata.append("bootstrap.servers=%s" % servers) |
| 425 | + if timeout is not None: |
| 426 | + metadata.append("conf.timeout=%d" % timeout) |
| 427 | + if eof is not None: |
| 428 | + metadata.append("conf.eof=%d" % (1 if eof else 0)) |
| 429 | + if conf is not None: |
| 430 | + for e in conf: |
| 431 | + metadata.append(e) |
| 432 | + |
| 433 | + func_data = {"subscription": subscription, "metadata": metadata} |
| 434 | + def func_init(data): |
| 435 | + """func_init""" |
387 | 436 | resource, _, _ = kafka_ops.kafka_iterable_init( |
388 | | - subscription, metadata=metadata, |
| 437 | + data["subscription"], metadata=data["metadata"], |
389 | 438 | container=scope, shared_name=subscription) |
390 | | - capacity = 1 |
391 | | - while True: |
392 | | - value = kafka_ops.kafka_iterable_next(resource, capacity=capacity) |
393 | | - if tf.shape(value)[0].numpy() < capacity: |
394 | | - return |
395 | | - yield value |
| 439 | + return resource |
| 440 | + func_next = kafka_ops.kafka_iterable_next |
| 441 | + |
| 442 | + dtype = tf.string |
| 443 | + shape = tf.TensorShape([None]) |
396 | 444 |
|
397 | 445 | super(KafkaIOTensor, self).__init__( |
398 | | - dtype, shape, function, properties=None, internal=internal) |
| 446 | + dtype, shape, |
| 447 | + {"init": func_init, "next": func_next, "data": func_data}, |
| 448 | + properties=None, internal=internal) |
0 commit comments