From 27297e6754158f5d003be4c57e22071d7182ae3b Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Thu, 7 Feb 2019 11:19:28 +0000
Subject: [PATCH] Switch DistStrat revised API examples to TensorFlow 2 style.

---
 rfcs/20181016-replicator.md | 174 +++++++++++++++++++-----------------
 1 file changed, 92 insertions(+), 82 deletions(-)

diff --git a/rfcs/20181016-replicator.md b/rfcs/20181016-replicator.md
index 505e6e27b..16f52de56 100644
--- a/rfcs/20181016-replicator.md
+++ b/rfcs/20181016-replicator.md
@@ -691,76 +691,84 @@ Below is a simple usage example for an image classification use case.
 
 ```python
 with strategy.scope():
-  model = resnet.ResNetV1(resnet.BLOCKS_50)
-  optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
+  model = tf.keras.applications.ResNet50(weights=None)
+  optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)
 
 def input_fn(ctx):
   return imagenet.ImageNet(ctx.get_per_replica_batch_size(effective_batch_size))
 
-def step_fn(inputs):
-  image, label = inputs
+input_iterator = strategy.make_input_iterator(input_fn)
 
-  logits = model(images)
-  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
-      logits=logits, labels=label)
-  loss = tf.reduce_mean(cross_entropy)
-  train_op = optimizer.minimize(loss)
-  with tf.control_dependencies([train_op]):
-    return tf.identity(loss)
+@tf.function
+def train_step():
+  def step_fn(inputs):
+    image, label = inputs
 
-input_iterator = strategy.make_input_iterator(input_fn)
-per_replica_losses = strategy.run(step_fn, input_iterator)
-mean_loss = strategy.reduce(per_replica_losses)
-
-with tf.Session(config=session_config) as session:
-  session.run(strategy.initialize())
-  session.run(input_iterator.initialize())
-  for _ in range(num_train_steps):
-    loss = session.run(mean_loss)
-  session.run(strategy.finalize())
+    with tf.GradientTape() as tape:
+      logits = model(image)
+      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          logits=logits, labels=label)
+      loss = tf.reduce_mean(cross_entropy)
+
+    grads = tape.gradient(loss, model.trainable_variables)
+    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
+    return loss
+
+  per_replica_losses = strategy.run(step_fn, input_iterator)
+  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)
+  return mean_loss
+
+strategy.initialize()
+input_iterator.initialize()
+for _ in range(num_train_steps):
+  loss = train_step()
+strategy.finalize()
 ```
 
 #### Evaluation
 
 ```python
 with strategy.scope():
-  model = resnet.ResNetV1(resnet.BLOCKS_50)
+  model = tf.keras.applications.ResNet50(weights=None)
 
 def eval_input_fn(ctx):
   del ctx  # Unused.
   return imagenet.ImageNet(
       eval_batch_size, subset="valid", shuffle=False, num_epochs=1)
 
-def eval_top1_accuracy(inputs):
-  image, label = inputs
-  logits = model(images)
-  predicted_label = tf.argmax(logits, axis=1)
-  top_1_acc = tf.reduce_mean(
-      tf.cast(tf.equal(predicted_label, label), tf.float32))
-  return top1_acc
-
 eval_input_iterator = strategy.make_input_iterator(
     eval_input_fn, input_replication_mode=InputReplicationMode.SINGLE)
-per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
-mean_top1_acc = strategy.reduce(per_replica_top1_accs)
 
-with tf.Session(config=session_config) as session:
-  session.run(strategy.initialize())
+@tf.function
+def eval():
+  def eval_top1_accuracy(inputs):
+    image, label = inputs
+    logits = model(image)
+    predicted_label = tf.argmax(logits, axis=1)
+    top_1_acc = tf.reduce_mean(
+        tf.cast(tf.equal(predicted_label, label), tf.float32))
+    return top_1_acc
+
+  per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
+  mean_top1_acc = strategy.reduce(AggregationType.MEAN, per_replica_top1_accs)
+  return mean_top1_acc
+
+strategy.initialize()
+while True:
+  while not has_new_checkpoint():
+    sleep(60)
+
+  load_checkpoint()
+
+  # Do a sweep over the entire validation set.
+  eval_input_iterator.initialize()
   while True:
-    while not has_new_checkpoint():
-      sleep(60)
-
-    load_checkpoint()
-
-    # Do a sweep over the entire validation set.
-    session.run(eval_input_iterator.initialize())
-    while True:
-      try:
-        top1_acc = session.run(mean_top1_acc)
-        ...
-      except tf.errors.OutOfRangeError:
-        break
-  session.run(strategy.finalize())
+    try:
+      top1_acc = eval()
+      ...
+    except tf.errors.OutOfRangeError:
+      break
+strategy.finalize()
 ```
 
 #### Sharded Input Pipeline
@@ -801,42 +809,43 @@ with strategy.scope():
   discriminator = GoodfellowDiscriminator(DefaultDiscriminator2D())
   generator = DefaultGenerator2D()
   gan = GAN(discriminator, generator)
-  disc_optimizer = tf.train.AdamOptimizer(disc_learning_rate, beta1=0.5, beta2=0.9)
-  gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate, beta1=0.5, beta2=0.9)
+  disc_optimizer = tf.keras.optimizers.Adam(disc_learning_rate)
+  gen_optimizer = tf.keras.optimizers.Adam(gen_learning_rate)
 
 def discriminator_step(inputs):
   image, noise = inputs
-  gan_output = gan.connect(image, noise)
-  disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
-  disc_train_op = disc_optimizer.minimize(disc_loss, var_list=disc_vars)
-
-  with tf.control_dependencies([disc_train_op]):
-    return tf.identity(disc_loss)
+
+  with tf.GradientTape() as tape:
+    gan_output = gan.connect(image, noise)
+    disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()
+
+  grads = tape.gradient(disc_loss, disc_vars)
+  disc_optimizer.apply_gradients(list(zip(grads, disc_vars)))
+  return disc_loss
 
 def generator_step(inputs):
   image, noise = inputs
-  gan_output = gan.connect(image, noise)
-  gen_loss, gen_vars = gan_output.generator_loss_and_vars()
-  gen_train_op = gen_optimizer.minimize(gen_loss, var_list=gen_vars)
-
-  with tf.control_dependencies([gen_train_op]):
-    return tf.identity(gen_loss)
+
+  with tf.GradientTape() as tape:
+    gan_output = gan.connect(image, noise)
+    gen_loss, gen_vars = gan_output.generator_loss_and_vars()
+
+  grads = tape.gradient(gen_loss, gen_vars)
+  gen_optimizer.apply_gradients(list(zip(grads, gen_vars)))
+  return gen_loss
 
 input_iterator = strategy.make_input_iterator(input_fn)
-per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
-per_replica_gen_losses = strategy.run(generator_step, input_iterator)
-mean_disc_loss = strategy.reduce(per_replica_disc_losses)
-mean_gen_loss = strategy.reduce(per_replica_gen_losses)
-
-with tf.Session() as session:
-  session.run(strategy.initialize())
-  session.run(input_iterator.initialize())
-  for _ in range(num_train_steps):
-    for _ in range(num_disc_steps):
-      disc_loss = session.run(mean_disc_loss)
-    for _ in range(num_gen_steps):
-      gen_loss = session.run(mean_gen_loss)
-  session.run(strategy.finalize())
+
+strategy.initialize()
+input_iterator.initialize()
+for _ in range(num_train_steps):
+  for _ in range(num_disc_steps):
+    per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
+    mean_disc_loss = strategy.reduce(AggregationType.MEAN, per_replica_disc_losses)
+  for _ in range(num_gen_steps):
+    per_replica_gen_losses = strategy.run(generator_step, input_iterator)
+    mean_gen_loss = strategy.reduce(AggregationType.MEAN, per_replica_gen_losses)
+strategy.finalize()
 ```
 
 ### Reinforcement Learning
@@ -846,11 +855,9 @@ This is an example of Reinforcement Learning system, converted to eager style.
 
 ```python
-tf.enable_eager_execution()
-
 with strategy.scope():
   agent = Agent(num_actions, hidden_size, entropy_cost, baseline_cost)
-  optimizer = tf.train.RMSPropOptimizer(learning_rate)
+  optimizer = tf.keras.optimizers.RMSprop(learning_rate)
 
   # Queues of trajectories from actors.
   queues = []
@@ -867,9 +874,12 @@ def learner_input(ctx):
   return dequeue_batch
 
 def learner_step(trajectories):
-  loss = tf.reduce_sum(agent.compute_loss(trajectories))
+  with tf.GradientTape() as tape:
+    loss = tf.reduce_sum(agent.compute_loss(trajectories))
+
   agent_vars = agent.get_all_variables()
-  optimizer.minimize(loss, var_list=agent_vars)
+  grads = tape.gradient(loss, agent_vars)
+  optimizer.apply_gradients(list(zip(grads, agent_vars)))
   return loss, agent_vars
 
 # Create learner inputs.
@@ -893,7 +903,7 @@ strategy.initialize()
 for _ in range(num_train_steps):
   per_replica_outputs = strategy.run(learner_step, learner_inputs)
   per_replica_losses, updated_agent_var_copies = zip(*per_replica_outputs)
-  mean_loss = strategy.reduce(per_replica_losses)
+  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)
 strategy.finalize()
 ```