diff --git a/rfcs/20200411-fuse_recv.md b/rfcs/20200411-fuse_recv.md
new file mode 100644
index 000000000..ec73aefe8
--- /dev/null
+++ b/rfcs/20200411-fuse_recv.md
@@ -0,0 +1,232 @@
# FuseRecv

| Status           | Proposed                                              |
| :--------------- | :---------------------------------------------------- |
| **Author(s)**    | Tongxuan Liu (tongxuan.ltx@alibaba-inc.com), Peng Tao (jiankeng.pt@alibaba-inc.com), Langshi Chen (langshi.cls@alibaba-inc.com) |
| **Reviewer(s)**  | Ayush Dubey (ayushd@google.com), Jeroen Bédorf (jeroen@minds.ai), Derek Murray (mrry@google.com), Bairen Yi (yibairen.byron@bytedance.com), Paul Tucker (paul.tucker@gmail.com) |
| **Sponsor**      | Ayush Dubey (ayushd@google.com) |
| **Updated**      | 2020-04-11 |

## Objective
This RFC proposes a new FuseRecv op that receives multiple tensors of
different types through a single Remote Procedure Call (RPC). This feature can
significantly reduce the number of RPC calls in most ranking or matching
models, such as those used in search, recommendation and advertising systems.

## Motivation
When many small tensors are transferred at around the same time, it is more
efficient to transfer them in a single RPC than to issue a separate RPC for
each of them.

When the neural network graph is complicated, each iteration through the graph
may introduce tens or even hundreds of RPC operations between the running
nodes. In particular, there are typically a large number of small tensors, such
as the multiple feature columns that gather data from the same parameter
server. These tensors have no dependence on each other, and each feature column
results in at least one RPC operation in the forward stage.

In CTR (Click Through Rate) models, or models that are mostly sparse (such as
the matching or ranking models widely used in recommender and advertising
systems), there can be hundreds of feature columns. In our scenario, each
sample includes at least hundreds of features. One training job normally uses
thousands of workers and tens of parameter servers. A worker generally has to
get variables from all the parameter servers, and each feature column issues at
least one request to the parameter servers in the forward stage. This means
hundreds of RPC operations for these feature columns alone, and even more for
some of the large feature columns (such as ids), which are partitioned into
dozens of RPCs per feature column. In summary, there are at least hundreds of
RPCs per worker per step for these feature columns only, and hundreds of
thousands of RPCs per step across the parameter servers in the forward stage.
Most feature columns gather only very small tensors from the parameter server,
usually less than 100KB. Logically these small tensors could be sent together
(i.e. fused). Furthermore, tensors that belong to the same layer can also be
fused before transfer, which would significantly reduce the number of RPC
operations.

Each RPC operation introduces overhead beyond the actual tensor data transfer,
including:
* Serialization/deserialization of each RPC request and response.
* The execution engine overhead of executing a Recv node, and the corresponding
  thread pool action required to execute the RPC callback function.

## User Benefit

Performance improvement: from performance benchmarking of the feature during
large (end-user) training jobs (> 400 workers), we normally see training run
1.5-2x faster in the parameter-server/worker setup.
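To make the RPC-count arithmetic in the Motivation section concrete, the
following back-of-the-envelope estimate uses assumed workload numbers in the
ranges quoted above; the specific figures are illustrative, not measurements.

```python
# Assumed workload, illustrative only (in the ranges quoted in the Motivation).
num_feature_columns = 500          # feature columns per worker
num_parameter_servers = 10         # parameter servers holding the variables
num_large_columns = 20             # e.g. id columns partitioned into shards
partitions_per_large_column = 20   # "dozens of RPCs per feature column"

# Without fusion: at least one Recv RPC per feature column in the forward
# stage, plus the extra partitions of the large columns.
rpcs_per_worker = (num_feature_columns
                   + num_large_columns * (partitions_per_large_column - 1))

# With FuseRecv: Recvs that share the same send/recv device pair can be
# grouped, so in the ideal case a worker issues one fused RPC per parameter
# server in the forward stage.
fused_rpcs_per_worker = num_parameter_servers

print(rpcs_per_worker, fused_rpcs_per_worker)  # 880 vs. 10 per worker per step
```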
## Design Proposal

![Figure 1: Current graph partition strategy](20200411-fuse_recv/current_graph_partition_strategy.png "Current graph partition strategy")
![Figure 2: Graph partition strategy with FuseRecv](20200411-fuse_recv/graph_partition_strategy_with_fuse_recv.png "Graph partition strategy with FuseRecv")

In the original Send/Recv design, each Recv node receives only one tensor, even
when several Recv ops feed the same destination op. Moreover, each Recv node
triggers one RPC operation even if the received tensor is a scalar.

In the proposed design, we traverse the (partitioned) graphs according to their
topology and iteratively replace Recv nodes with the new FuseRecv nodes. Please
refer to the details in Section
[FuseRecv Optimizer in Grappler](#fuserecv-optimizer-in-grappler).

As shown in Figures 1 and 2, instead of adding a Recv node for each of the
tensors ‘a’ and ‘x’, we use a single FuseRecv node to replace the two Recv
nodes and fetch the two tensors together. The FuseRecv node has two output
‘slots’ (‘ports’): slot 0 feeds inputs ‘b’ and ‘c’, and slot 1 feeds ‘y’.
Notice that, because the RPC operation is Recv-driven, there is no need to fuse
the Send nodes.

A new RPC method ‘FuseRecvTensorAsync’ and its handler (FuseRecvTensorHandlerRaw)
are added to WorkerInterface and WorkerService. FuseRecvTensor follows the same
optimization steps as RecvTensor to avoid copying the response buffer.

### Alternatives Considered
#### Fuse the tensors into a single Send/Recv, Solution 1 (Derek Murray)
Pack the N tensors to be sent into a length-N DT_VARIANT vector.

Pros: Reuses the current RPC and avoids potentially intricate changes to the
zero-copy response buffer code.

Cons: Introduces memory-copy overhead.

#### Fuse the tensors into a single Send/Recv, Solution 2 (Derek Murray)
Pack the tensor contents into a single flattened buffer. This would be very
similar to the ScopedAllocator optimization that ayushd@google.com and
tucker@google.com implemented for collectives, and it might be possible to
reuse some of the graph analysis code.

Pros: Reuses the current RPC and avoids potentially intricate changes to the
zero-copy response buffer code.

Cons: The fused tensors could be of different types and dynamic shapes, which
this solution cannot handle.

#### Dynamic Fusion at Runtime (Paul Tucker)
Instead of adding a new FuseRecvTensor method to the Worker interface, we add a
slightly different RecvSomeTensors method. The client sends the server a list
of keys for which it is ready to receive values, and the server streams back
one or more values when it is ready. It is the responsibility of the client to
retry any key that was not included in the response.

To make this work well there needs to be some dynamic bundling on each side.
For example, on the client side a call to RecvTensor on the local Rendezvous
for a remote value does not necessarily result in an immediate RPC. It might if
the value is expected to be large, but it might also just add the key to a
ready set associated with the remote host. An RPC may not be sent until the
ready set reaches a certain size, or a minimum time has elapsed since the last
RPC against that host was started. When the response is received, any missing
keys go back into the ready set.
On the server side there could be some logic to decide, for a RecvSomeTensors
call, whether to wait for more of the requested values to be ready or to
immediately send what is available now and let the client re-request anything
missing.

Pros: Dynamic fusion at runtime could yield better results, and it also makes
it possible to control the priority of tensors (i.e. which Recv is more
important).

Cons: A potential bottleneck is the time window of the ready set. The best
value differs widely across models, so setting it manually would be hard. This
solution remains another good candidate alongside FuseRecv.

### Performance Implications
With a Wide & Deep model, the number of RPC calls per step was reduced by 55%,
and overall training throughput increased by 40%.
![Figure 3: performance_result](20200411-fuse_recv/performance_result.png "Performance result")

### Dependencies
* None

### Engineering Impact
* Engineering impact: Once the feature is (manually) enabled (via
  ConfigProto.GraphOptions.do_fuse_recv), test times would be longer because
  the FuseRecv post-partitioning optimizer traverses and updates the graph.
* Maintenance: Minimal maintenance overhead. The TensorFlow team and
  contributors will maintain the documentation and keep it up to date. Changes
  should be reviewed and approved by the TensorFlow team leads.

### Platforms and Environments
* Platforms: The feature is independent of platforms.
* Execution environments (cloud services, accelerator hardware): The first
  stage supports CPU and GPU devices. We intend to support additional devices
  where possible.

### Best Practices
* We strongly suggest enabling FuseRecv for ranking or matching models such as
  [W&DL](https://arxiv.org/abs/1606.07792) and [DIEN](https://arxiv.org/abs/1809.03672).

### Tutorials and Examples
Example of how to enable the FuseRecv feature:

```
 >>> tf.config.optimizer.set_experimental_options({"do_fuse_recv": True})
```

### Compatibility
* This feature works with the ParameterServerStrategy.
* This feature considers tensors on different devices such as CPU, GPU and TPU.
* Independent of SavedModel or checkpoints.

### User Impact
* None

## Detailed Design

### FuseRecv Op
We introduce the _RecvV2 op and an RPC operation named FuseRecvTensorAsync in
RemoteWorker and WorkerService. The _RecvV2 op definition is as follows:

```
REGISTER_OP("_RecvV2")
    .Output("tensor: tensor_type")
    .Attr("tensor_type: list(type)")
    .Attr("tensor_name: list(string)")
    .Attr("send_device: string")
    .Attr("send_device_incarnation: int")
    .Attr("recv_device: string")
    .Attr("client_terminated: bool = false")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
```

FuseRecv requests a list of tensors of different types from a remote device.
Generally we only fuse Recv ops that share the same send device and the same
recv device.

### FuseRecv Optimizer in Grappler
During the post-partition phase, we add a new pass called “FuseRecv” to the
post-partitioning optimizer to fuse Recv ops together. The pass traverses the
partitioned graphs together with the whole graph according to their topology,
iteratively searching for Recv ops that can be fused and replacing them with
FuseRecv ops. See Figure 4 for the formal algorithm definition.
![Figure 4: fuse_recv_procedure](20200411-fuse_recv/fuse_recv_procedure.png "Fuse Recv Procedure")

The procedure RECVFUSE takes two input arguments: 1) the TF computation graph
g, and 2) a partitioned graph. It is worth noting that the iteration over all
nodes starts from the `root` nodes, which have no incoming (source) edges. The
process between lines 17 and 37 is executed iteratively and outputs key-value
pairs, where each value is a group of edges that can be fused into one FuseRecv
node. Based on these grouped edges, we then find the Recv nodes in the
partitioned graph that can be replaced by FuseRecv nodes. RECVFUSE also makes
sure that no deadlock is introduced by the change to the original graph. In
addition, the FuseRecvTensor RPC is able to overlap computation and
communication by exploiting the graph topology.

### FuseRecv RPC Method and Handler
A new RPC method ‘FuseRecvTensorAsync’ is added to the WorkerInterface. The
‘FuseRecvTensorAsync’ method is able to handle multiple rendezvous keys and
fetch the corresponding tensors in a single call.

On the server side, we add a ‘FuseRecvTensorHandlerRaw’, which handles the
multiple rendezvous keys for the ‘local recv’ instantiated by the local tensor
operations. As mentioned before, the Send nodes are not fused, so we must
perform multiple local recvs, one for each of the corresponding Send nodes.

Because the ‘FuseRecvTensorAsync’ handler might be executed before the Send
operations happen, a callback wrapper is required. We use a counter initialized
with the fuse count; each Send action triggers the callback wrapper and
atomically decrements the counter. When the counter reaches 0, the real
callback is executed and the tensors are sent back to the FuseRecv node.

### Dead Tensor Handling
We treat the output of the FuseRecv node as dead if and only if all the fused
tensors are dead.

### FuseRecv Error Handling
The status of the FuseRecv node is similar to that of the Recv node, but it
includes additional status information for every received tensor.
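The counter-based callback wrapper and the dead-tensor rule above can be
illustrated with a short sketch. This is a simplified Python illustration of
the mechanism only, not the actual TensorFlow C++ runtime code; the class and
method names are invented for the example.

```python
import threading

class FuseRecvCallbackWrapper:
  """Fires the real FuseRecv response callback once all fused sends are done.

  Illustrative sketch only; not part of the TensorFlow runtime.
  """

  def __init__(self, fuse_count, done_callback):
    self._remaining = fuse_count            # initialized with the fuse count
    self._lock = threading.Lock()
    self._tensors = [None] * fuse_count
    self._is_dead = [True] * fuse_count
    self._done_callback = done_callback

  def on_send(self, index, tensor, is_dead):
    """Called once per fused rendezvous key when its local send completes."""
    with self._lock:                        # stands in for an atomic decrement
      self._tensors[index] = tensor
      self._is_dead[index] = is_dead
      self._remaining -= 1
      done = (self._remaining == 0)
    if done:
      # The FuseRecv output is treated as dead iff all fused tensors are dead.
      self._done_callback(self._tensors, all(self._is_dead))
```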
## Questions and Discussion Topics

diff --git a/rfcs/20200411-fuse_recv/current_graph_partition_strategy.png b/rfcs/20200411-fuse_recv/current_graph_partition_strategy.png
new file mode 100644
index 000000000..1d882cd96
Binary files /dev/null and b/rfcs/20200411-fuse_recv/current_graph_partition_strategy.png differ
diff --git a/rfcs/20200411-fuse_recv/fuse_recv_procedure.png b/rfcs/20200411-fuse_recv/fuse_recv_procedure.png
new file mode 100644
index 000000000..359b006ee
Binary files /dev/null and b/rfcs/20200411-fuse_recv/fuse_recv_procedure.png differ
diff --git a/rfcs/20200411-fuse_recv/graph_partition_strategy_with_fuse_recv.png b/rfcs/20200411-fuse_recv/graph_partition_strategy_with_fuse_recv.png
new file mode 100644
index 000000000..58c7887eb
Binary files /dev/null and b/rfcs/20200411-fuse_recv/graph_partition_strategy_with_fuse_recv.png differ
diff --git a/rfcs/20200411-fuse_recv/performance_result.png b/rfcs/20200411-fuse_recv/performance_result.png
new file mode 100644
index 000000000..54100ba5a
Binary files /dev/null and b/rfcs/20200411-fuse_recv/performance_result.png differ
diff --git a/rfcs/20200420-tfx-tuner-component.md b/rfcs/20200420-tfx-tuner-component.md
new file mode 100644
index 000000000..57f550f35
--- /dev/null
+++ b/rfcs/20200420-tfx-tuner-component.md
@@ -0,0 +1,383 @@
# TFX Tuner Component

| Status        | Proposed                                                   |
| :------------ | :--------------------------------------------------------- |
| **Author(s)** | Jiayi Zhao (jyzhao@google.com), Amy Wu (wuamy@google.com)   |
| **Sponsor**   | Zhitao Li (zhitaoli@google.com), Tom O'Malley (omalleyt@google.com), Matthieu Monsch (mtth@google.com), Makoto Uchida (muchida@google.com), Goutham Bhat (goutham@google.com) |
| **Updated**   | 2020-04-20 |

## Objective

### Goal

* A new Tuner component in TFX for automated hyperparameter tuning, based on
  abstractions from the
  [KerasTuner library](https://github.com/keras-team/keras-tuner), in order to
  reuse the latter's abstractions and algorithms.

### Non Goal

* Natively support multi-worker tuning in the system. As TFX does not have the
  ability to manage multi-worker clusters, running multiple trials in parallel
  (parallel tuning) and running each trial in a distributed environment
  (distributed training) are not supported natively. Parallel tuning may
  instead be realized by a particular implementation of the TFX Tuner (a custom
  Executor), e.g., in a Google Cloud environment.
* Implementation of custom tuners for the
  [KerasTuner library](https://github.com/keras-team/keras-tuner) is out of
  scope of this design discussion, e.g., built-in EstimatorTuner support.
  However, user projects can still implement a tuner that inherits from
  [`kerastuner.BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
  and provide it to the proposed TFX Tuner component.

## Background and Motivation

A hyperparameter is a parameter whose value is used to control the learning
process or the model itself (e.g., the number of layers and the number of
nodes). By contrast, the values of other parameters (typically node weights)
are learned.

Hyperparameter optimization is a critical part of many machine learning
pipelines. We therefore propose a new TFX component that, given a search space
specifying the hyperparameter configuration (name, type, range, etc.),
optimizes the hyperparameters based on the tuning algorithm.
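To make the notion of a search space concrete, the KerasTuner `HyperParameters`
container (described in the following sections) declares each hyperparameter
with a name, a type and a range or set of choices. A minimal sketch follows;
the specific names and ranges are only examples.

```python
import kerastuner

hp = kerastuner.HyperParameters()
# Each declaration gives a name, a type and a range (or a set of choices).
hp.Float('learning_rate', min_value=1e-4, max_value=1e-1, sampling='log')
hp.Int('num_layers', min_value=1, max_value=5)
hp.Choice('activation', ['relu', 'tanh'])
```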
## User Benefit

This document proposes a built-in TFX Tuner component that works seamlessly
with Trainer and other TFX components. Because the Tuner component utilizes the
[KerasTuner library](https://github.com/keras-team/keras-tuner), all supported
tuning methods are available to TFX, including custom implementations of
KerasTuner tuners.

## Design Proposal

The TFX Tuner component will be built with the
[KerasTuner library](https://github.com/keras-team/keras-tuner). In the
following sections, we first briefly go over the KerasTuner library and several
concepts in hyperparameter optimization. Then we focus on the Tuner component
interface and how we utilize the KerasTuner library. After that, we discuss
parallel tuning and our plan for Google Cloud integration.

### KerasTuner Library

The following figure shows a typical workflow of hyperparameter tuning under
the KerasTuner framework:

![KerasTuner tuning workflow](20200420-tfx-tuner-component/workflow.png)
Given the user-provided model, which accepts a hyperparameter container, the
tuner searches the space through trials created by the tuning algorithm. For
each trial, values within the search space are assigned to the hyperparameter
container, and the user model is trained with these hyperparameter values and
evaluated based on the objective provided to the tuner. The evaluation results
are reported back to the tuner, and the tuning algorithm decides the
hyperparameter values for the next trial. After a certain condition is reached,
e.g., the maximum number of trials, the tuner stops iterating and returns the
optimal hyperparameters.

The KerasTuner library provides the tuning functionality described above; its
main abstractions are:

* [`HyperParameters`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/hyperparameters.py):
  Hyperparameter container for both the search space and the current values.
* [`Oracle`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/oracle.py):
  Implementation of a hyperparameter tuning algorithm, e.g., random search,
  including state management of the algorithm’s progress.
* [`Trial`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/trial.py):
  Provided by the Oracle; contains the hyperparameter values for the current
  iteration.
* [`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py):
  A base tuner interface for the above tuning workflow, responsible for the
  iteration of trial execution:
  * Generates a Trial using the Oracle.
  * Trains the user model with the HyperParameters in the current Trial.
  * Evaluates metrics and reports them back to the Oracle for the next Trial.
* [`Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py):
  An implementation of BaseTuner for Keras model tuning.

Note: Other than `Tuner`, the abstractions defined by `HyperParameters`,
`Oracle`, `Trial` and `BaseTuner` are not restricted to Keras models, even
though the library is called KerasTuner.

For more details and code examples, please refer to
[the KerasTuner repository](https://github.com/keras-team/keras-tuner).

### Component Interface

The Tuner component takes raw or transformed examples as input, along with a
schema or transform graph for the feature specification, and outputs the
hyperparameter tuning results. The specification of the Tuner component is
shown below:

```python
class TunerSpec(ComponentSpec):
  """ComponentSpec for TFX Tuner Component."""

  PARAMETERS = {
      # Specify a python module file which contains a UDF `tuner_fn`.
      'module_file': ExecutionParameter(type=(str, Text), optional=True),
      # Specify the steps for the training stage of each trial’s execution.
      'train_args': ExecutionParameter(type=trainer_pb2.TrainArgs),
      'eval_args': ExecutionParameter(type=trainer_pb2.EvalArgs),
  }

  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
      'schema': ChannelParameter(
          type=standard_artifacts.Schema, optional=True),
      'transform_graph':
          ChannelParameter(
              type=standard_artifacts.TransformGraph, optional=True),
  }

  OUTPUTS = {
      'best_hyperparameters':
          ChannelParameter(type=standard_artifacts.HyperParameters),
  }
```

Trainer has an optional hyperparameters input; the tuning result can be fed
into it so that Trainer can use the best hyperparameters to construct the
model. Below is an example of how Tuner and Trainer are chained in the
pipeline:

```python
# TrainerSpec:
  INPUTS = {
      ...
      'hyperparameters':
          ChannelParameter(
              type=standard_artifacts.HyperParameters, optional=True),
  }

# Pipeline DSL Example:
  tuner = Tuner(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      eval_args=trainer_pb2.EvalArgs(num_steps=500))

  trainer = Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      hyperparameters=tuner.outputs['best_hyperparameters'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))
```

For Trainer, users need to define the model code and training logic
([Generic Trainer](https://github.com/tensorflow/tfx/blob/r0.21.2/docs/guide/trainer.md#generic-trainer))
in the module file. For Tuner, in addition to the model code, users also need
to define the hyperparameters, the search space and a tuning algorithm in the
module file. A `tuner_fn` with the following signature is required for Tuner:

```python
from typing import Any, Dict, NamedTuple, Text

from kerastuner.engine import base_tuner
import tensorflow as tf
from tfx.components.trainer.executor import TrainerFnArgs

# The current TrainerFnArgs will be renamed to FnArgs as a util class.
FnArgs = TrainerFnArgs
TunerFnResult = NamedTuple('TunerFnResult',
                           [('tuner', base_tuner.BaseTuner),
                            ('fit_kwargs', Dict[Text, Any])])

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Builds the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      working_dir: working dir for tuning. Automatically set by the Executor.
      train_files: List of file paths containing training tf.Example data.
      eval_files: List of file paths containing eval tf.Example data.
      train_steps: number of train steps.
      eval_steps: number of eval steps.
      schema: optional schema file of the input data.
      transform_graph: optional transform graph produced by TFT.

  Returns:
    A namedtuple containing the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to the tuner’s run_trial function for fitting
        the model, e.g., the training and validation datasets. The required
        args depend on the above tuner’s implementation.
  """
```

The TunerFnResult returned by `tuner_fn` contains an instance that implements
the
[`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
interface; that is the contract required by Tuner for tuning. The model code,
hyperparameters, search space and tuning algorithm are hidden under the
BaseTuner abstraction, so the Tuner component itself is generic and agnostic to
the model framework and tuning logic. Below is an example module file with a
Keras model:

```python
import kerastuner
import tensorflow as tf
...

def _input_fn(file_pattern: Text, ...) -> tf.data.Dataset:
  ...

# Model code shared by Trainer and Tuner.
def _build_keras_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
  ...
  for _ in range(hp.get('num_layers')):
    ...
  ...
  model = tf.keras.Model(...)
  model.compile(
      optimizer=tf.keras.optimizers.Adam(hp.get('learning_rate')),
      loss='sparse_categorical_crossentropy',
      metrics=[tf.keras.metrics.Accuracy()])
  return model

# This will be called by the TFX Tuner component.
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  hp = kerastuner.HyperParameters()
  # Defines the search space.
  hp.Choice('learning_rate', [1e-1, 1e-3])
  hp.Int('num_layers', 1, 5)

  # RandomSearch is a subclass of kerastuner.Tuner (the Keras model tuner).
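  # With `allow_new_entries=False`, _build_keras_model may only request
  # hyperparameters that are pre-declared in `hp` above, so the search space
  # stays exactly as defined in this module file.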
  tuner = kerastuner.RandomSearch(
      _build_keras_model,
      max_trials=5,
      hyperparameters=hp,
      allow_new_entries=False,
      objective='val_accuracy',
      directory=fn_args.working_dir,
      project_name='test')

  train_dataset = _input_fn(fn_args.train_files, ...)
  eval_dataset = _input_fn(fn_args.eval_files, ...)

  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={'x': train_dataset,
                  'validation_data': eval_dataset,
                  'steps_per_epoch': fn_args.train_steps,
                  'validation_steps': fn_args.eval_steps})

# This will be called by the TFX Generic Trainer.
def run_fn(fn_args: FnArgs) -> None:
  hp = kerastuner.HyperParameters.from_config(fn_args.hyperparameters)
  model = _build_keras_model(hp)
  model.fit(...)
  model.save(...)
```

In Tuner’s executor, `tuner_fn` is called with information resolved from the
component inputs; we then call the `search` function of the returned tuner with
`fit_kwargs` to launch trials for tuning, and finally emit the best trial’s
hyperparameters:

```python
# Executor of the Tuner component:
class Executor(base_executor.BaseExecutor):

  def Do(self,
         input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    ...
    tuner_spec = tuner_fn(self._create_fn_args(input_dict, exec_properties))
    tuner_spec.tuner.search(**tuner_spec.fit_kwargs)
    # The output file contains the JSON string of hyperparameters.get_config().
    self._emit_best_hyperparameters(
        output_dict, tuner_spec.tuner.get_best_hyperparameters()[0])
```

### Parallel Tuning

In parallel tuning, multiple trials are executed in parallel. In this section,
we discuss how distribution works for the KerasTuner library and the current
status of TFX.

In the `search` function of a single tuner, trials run in sequence rather than
in parallel. To support parallel tuning, we need to launch multiple tuners (the
tuner here refers to the one in the KerasTuner library, not the TFX Tuner
component) and a centralized optimization service that manages the state of the
tuning algorithm; the oracle of each tuner communicates with this service to
retrieve the trials for its tuner.

![Parallel tuning with multiple tuners](20200420-tfx-tuner-component/parallel_tuning.png)
The figure above shows parallel tuning with three tuners. Each tuner runs as a
different worker and retrieves trials from its own oracle, which talks to the
optimization service. Trials of different tuners can run in parallel, but
trials within the same tuner still execute in sequence. When the tuners are
launched, the same identifier is assigned to each oracle, so the optimization
service knows they belong to the same tuning job group and assigns
hyperparameter values for their trials based on the algorithm.

The number of parallel tuners can be passed to the component via the `TuneArgs`
shown below:

```python
# Args specific to tuning.
message TuneArgs {
  # Number of trials to run in parallel.
  # Each trial will be trained and evaluated by separate worker jobs.
  int32 num_parallel_trials = 1;
}

class TunerSpec(ComponentSpec):

  PARAMETERS = {
      ...
      'tune_args': ExecutionParameter(type=tuner_pb2.TuneArgs),
  }
```

The KerasTuner library allows users to configure
[`tf.distribute.Strategy`](https://www.tensorflow.org/tutorials/distribute/keras)
if they are using the
[`kerastuner.Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py)
class (or subclasses of it). In the parallel tuning above, each trial (each
model training) is executed on a single worker, so only single-machine
strategies are allowed. To support multi-worker distributed training, we need
to be able to execute a trial (training) across different workers.

At the time of writing, the KerasTuner library can be used for parallel tuning
with a single-machine `tf.distribute.Strategy`, e.g.,
[`MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy);
multi-worker strategy support (distributed training for a trial) is on the
roadmap (note that cluster management is not part of the library).

At the time of writing, TFX does not have the ability to manage the
multi-worker cluster and the centralized optimization service, so parallel
tuning and distributed training are not supported natively in TFX (local or
on-prem). In the next section we discuss the integration for Google Cloud;
similar parallel tuning support can be built for other execution environments.

### Google Cloud Integration

In this section, we discuss the Tuner component with
[Google Cloud AI Platform](https://cloud.google.com/ai-platform) (CAIP),
specifically, an implementation of the KerasTuner Oracle that talks to
[AI Platform Optimizer](https://cloud.google.com/ai-platform/optimizer/docs/overview)
as the centralized optimization service, and a custom Tuner executor
implementation that makes use of the Cloud Optimizer-based Oracle (symbol names
are subject to change).

As mentioned in the parallel tuning section above, KerasTuner uses a
centralized optimization service that manages the states of a tuning study and
its trials. In addition, we will create a `CloudOracle` as a client to the AI
Platform Optimizer service, and a `CloudTuner` that inherits from the Keras
[Tuner](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py).
In the module file, users create the `tuner_fn` with `CloudTuner`, and then
configure the TFX Tuner component to use a custom Tuner executor
(`CloudExecutor`), which launches multiple `CloudTuner`s on a Google Cloud AI
Platform Training job, possibly with multiple worker machines running various
trials concurrently. The workflows for in-process tuning and Cloud tuning are
shown below.
![Workflows of in-process tuning and Cloud tuning](20200420-tfx-tuner-component/cloud.png)
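As a rough illustration of how a pipeline might opt into Cloud tuning, the
sketch below overrides the Tuner executor so that trials run on AI Platform.
`CloudExecutor` is the hypothetical executor described above, and the
`custom_executor_spec` override is shown here as an assumption about how the
component would be wired up, not a finalized interface.

```python
# Pipeline DSL sketch (assumptions, not a finalized API): the module file's
# tuner_fn is expected to construct a CloudTuner instead of RandomSearch.
from tfx.components.base import executor_spec

tuner = Tuner(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=module_file,
    train_args=trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=trainer_pb2.EvalArgs(num_steps=500),
    # Hypothetical executor that launches CloudTuners on an AI Platform
    # Training job and uses CloudOracle as the optimization-service client.
    custom_executor_spec=executor_spec.ExecutorClassSpec(CloudExecutor))
```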
## Future work

* Native support for multi-worker parallel tuning.
* Custom Tuner (inheriting from BaseTuner) examples, e.g., for Estimator
  support or Keras custom training loop support.
diff --git a/rfcs/20200420-tfx-tuner-component/cloud.png b/rfcs/20200420-tfx-tuner-component/cloud.png
new file mode 100644
index 000000000..09559da71
Binary files /dev/null and b/rfcs/20200420-tfx-tuner-component/cloud.png differ
diff --git a/rfcs/20200420-tfx-tuner-component/parallel_tuning.png b/rfcs/20200420-tfx-tuner-component/parallel_tuning.png
new file mode 100644
index 000000000..efd62b113
Binary files /dev/null and b/rfcs/20200420-tfx-tuner-component/parallel_tuning.png differ
diff --git a/rfcs/20200420-tfx-tuner-component/workflow.png b/rfcs/20200420-tfx-tuner-component/workflow.png
new file mode 100644
index 000000000..4f8bd89da
Binary files /dev/null and b/rfcs/20200420-tfx-tuner-component/workflow.png differ