@@ -69,9 +69,10 @@ def registry_generator_fn(layer_instance, args, kwargs):

 def compute_gradient_norms(
     input_model: tf.keras.Model,
+    layer_registry: lr.LayerRegistry,
     x_batch: InputTensor,
     y_batch: tf.Tensor,
-    layer_registry: lr.LayerRegistry,
+    weight_batch: Optional[tf.Tensor] = None,
     per_example_loss_fn: Optional[LossFn] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     trainable_vars: Optional[List[tf.Variable]] = None,
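The reordered signature puts `layer_registry` immediately after the model and introduces an optional `weight_batch`, so positional call sites need updating. A minimal usage sketch, assuming a `clip_grads` module name and a `make_default_layer_registry()` helper in `tensorflow_privacy.privacy.fast_gradient_clipping` (neither is shown in this diff):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Small compiled model; per the docstring, the model loss must be a scalar loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError())

x = tf.random.normal([8, 3])
y = tf.random.normal([8, 2])
w = tf.ones([8])  # optional per-example weights

# `layer_registry` is now the second positional argument, before the batch.
norms = clip_grads.compute_gradient_norms(
    model,
    layer_registry.make_default_layer_registry(),
    x,
    y,
    weight_batch=w,
)
print(norms.shape)  # one gradient norm per example: (8,)
```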
@@ -84,15 +85,16 @@ def compute_gradient_norms(
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from. The
       loss of the model *must* be a scalar loss.
+    layer_registry: A `LayerRegistry` instance containing functions that help
+      compute gradient norms quickly. See
+      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
+      more details.
     x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
       number of examples in `x_batch`.
-    layer_registry: A `LayerRegistry` instance containing functions that help
-      compute gradient norms quickly. See
-      `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
-      more details.
+    weight_batch: Optional batch of weights, passed to the loss function.
     per_example_loss_fn: takes as input predictions, labels and weights, and
       outputs a vector of per-example losses. If None, derived from
       `input_model.loss` by disabling its reduction.
@@ -108,8 +110,9 @@ def compute_gradient_norms(
       variables are included.

   Returns:
-    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
-    per-example loss function.
+    A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
+    example loss (when num_microbatches is None) or of the i-th microbatch loss
+    (defined as the mean loss over the microbatch).
   """
   tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
   registry_generator_fn = get_registry_generator_fn(
@@ -127,7 +130,7 @@ def compute_gradient_norms(
     loss_config = input_model.loss.get_config()
     loss_config['reduction'] = tf.keras.losses.Reduction.NONE
     per_example_loss_fn = input_model.loss.from_config(loss_config)
-  losses = per_example_loss_fn(y_batch, model_outputs)
+  losses = per_example_loss_fn(y_batch, model_outputs, weight_batch)
   if losses.shape is None:
     raise NotImplementedError(
         "The unreduced (or per-example) loss's shape cannot be `None`"
@@ -140,7 +143,7 @@ def compute_gradient_norms(
     )
   if num_microbatches is not None:
     losses = tf.reduce_mean(
-        lr.add_microbatch_axis(losses, num_microbatches), axis=1
+        lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1
     )
   summed_loss = tf.reduce_sum(losses)
   # Unwrap the generator outputs so that the next loop avoids duplicating
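`maybe_add_microbatch_axis` itself is not part of this diff; judging from the call site above, the rename from `add_microbatch_axis` suggests a helper that passes the tensor through unchanged when `num_microbatches` is None and otherwise folds the batch axis into microbatches. A hedged sketch of that assumed behavior:

```python
import tensorflow as tf

def maybe_add_microbatch_axis(x, num_microbatches):
  """Assumed behavior: reshape into microbatches, or pass through unchanged."""
  if num_microbatches is None:
    return x  # no microbatching requested
  # [batch_size, ...] -> [num_microbatches, batch_size // num_microbatches, ...]
  new_shape = tf.concat([[num_microbatches, -1], tf.shape(x)[1:]], axis=0)
  return tf.reshape(x, new_shape)

losses = tf.range(8, dtype=tf.float32)               # per-example losses, batch of 8
grouped = maybe_add_microbatch_axis(losses, 4)       # shape [4, 2]
microbatch_losses = tf.reduce_mean(grouped, axis=1)  # shape [4], as in the hunk above
```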
@@ -165,8 +168,10 @@ def compute_gradient_norms(
       vars_list,
       unconnected_gradients=tf.UnconnectedGradients.ZERO,
   )
+  if not grads_list:
+    raise ValueError('Empty gradient list.')
   sqr_norm_list = []
-  for grads, f in zip(grads_list, sqr_norm_fns_list):
+  for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True):
     sqr_norm_list.append(f(grads))
   sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1)
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
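`zip(..., strict=True)` (Python 3.10+) makes a length mismatch between `grads_list` and `sqr_norm_fns_list` fail loudly instead of silently truncating, complementing the new empty-list check. A quick illustration of the behavior:

```python
# With strict=True, mismatched lengths raise instead of being silently dropped.
grads_list = [0.5, 1.5, 2.5]
sqr_norm_fns_list = [lambda g: g * g] * 2  # one entry short on purpose

try:
  for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True):
    print(f(grads))  # prints 0.25 and 2.25, then the mismatch is detected
except ValueError as err:
  print('length mismatch caught:', err)
```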
@@ -199,10 +204,11 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):

 def compute_clipped_gradients_and_outputs(
     input_model: tf.keras.Model,
-    x_batch: InputTensor,
-    y_batch: tf.Tensor,
     l2_norm_clip: float,
     layer_registry: lr.LayerRegistry,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    weight_batch: Optional[tf.Tensor] = None,
     num_microbatches: Optional[lr.BatchSize] = None,
     clipping_loss: Optional[LossFn] = None,
 ) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]:
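The same reordering applies here: `l2_norm_clip` and `layer_registry` now come before the batch tensors, and `weight_batch` is new. A hedged call sketch that reuses the hypothetical `model`, registry, `x`, `y`, and `w` from the earlier example; the noise addition a real DP-SGD step would apply to `clipped_grads` is deliberately left out:

```python
clipped_grads, y_pred, clipping_loss_value = (
    clip_grads.compute_clipped_gradients_and_outputs(
        model,
        l2_norm_clip=1.0,
        layer_registry=layer_registry.make_default_layer_registry(),
        x_batch=x,
        y_batch=y,
        weight_batch=w,
    )
)
# Gradient of the reweighted loss: each per-example contribution is clipped to
# at most l2_norm_clip before sample weighting.
model.optimizer.apply_gradients(zip(clipped_grads, model.trainable_variables))
```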
@@ -218,11 +224,6 @@ def compute_clipped_gradients_and_outputs(

   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers from.
-    x_batch: An `InputTensor` representing a batch of inputs to the model. The
-      first axis must be the batch dimension.
-    y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
-      must be the batch dimension. The number of examples should match the
-      number of examples in `x_batch`.
     l2_norm_clip: A `float` indicating the norm to which per-example gradients
       will be clipped. That is, all gradients of the per-example loss functions
       will have norm at most `l2_norm_clip`.
@@ -232,6 +233,15 @@ def compute_clipped_gradients_and_outputs(
       `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
       squared norms of a layer's pre-activation tensor, and `vars` are relevant
       trainable weights (see `layer_registry_factories.py` for examples).
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
+      first axis must be the batch dimension.
+    y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
+      must be the batch dimension. The number of examples should match the
+      number of examples in `x_batch`.
+    weight_batch: Optional vector of weights, passed to the loss function. Must
+      be of size [batch_size]. In case of microbatching, this will be reshaped
+      to [num_microbatches, batch_size/num_microbatches] before passing it to
+      the loss.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number of
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -243,11 +253,10 @@ def compute_clipped_gradients_and_outputs(
       the value of the clipped loss does not reflect the true loss.

   Returns:
-    A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the
-      clipped gradient of the loss function, the second is the result of
-      applying `input_model` to `x_batch`, and the third is loss value of
-      `input_model`, weighted by the loss weights generated by a specific
-      `compute_clip_weights()` call.
+    clipped_grad: the clipped gradient of the loss function
+    y_pred: the result of applying `input_model` to `x_batch`
+    clipping_loss_value: the loss value weighted in such a way that its gradient
+      is `clipped_grad`.
   """
   if input_model.loss.reduction == 'none':
     raise NotImplementedError(
@@ -258,13 +267,25 @@ def compute_clipped_gradients_and_outputs(
     clipping_loss = input_model.compiled_loss
   gradient_norms = compute_gradient_norms(
       input_model,
+      layer_registry,
       x_batch,
       y_batch,
-      layer_registry,
+      weight_batch,
       num_microbatches=num_microbatches,
       trainable_vars=input_model.trainable_variables,
   )
-  loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
+  clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
+  if weight_batch is not None:
+    if num_microbatches is None:
+      clip_weights = clip_weights * weight_batch  # shape [batch_size]
+    else:
+      # Here weight_batch has shape [batch_size]; reshape it to
+      # [num_microbatches, microbatch_size] and multiply by clip_weights
+      # (which has shape [num_microbatches]).
+      weight_batch = lr.maybe_add_microbatch_axis(
+          weight_batch, num_microbatches
+      )
+      clip_weights = clip_weights[:, tf.newaxis] * weight_batch
   with tf.GradientTape() as tape:
     # WARNING: When num_microbatches is not None, we need to be sure that
     # `compute_loss` always computes the mean over the microbatches
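When both sample weights and microbatching are in play, the branch above broadcasts one clip weight per microbatch across that microbatch's examples. A small numeric sketch of the broadcast, assuming the reshape behaves as described in the docstring:

```python
import tensorflow as tf

num_microbatches = 2
clip_weights = tf.constant([0.5, 1.0])            # shape [num_microbatches]
weight_batch = tf.constant([1.0, 2.0, 3.0, 4.0])  # shape [batch_size] = [4]

# Reshape to [num_microbatches, microbatch_size] as the docstring describes.
weight_batch = tf.reshape(weight_batch, [num_microbatches, -1])  # [[1, 2], [3, 4]]

# Broadcast each microbatch's clip weight across its examples.
combined = clip_weights[:, tf.newaxis] * weight_batch
print(combined.numpy())  # [[0.5, 1.0], [3.0, 4.0]] -- shape [2, 2]
```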
@@ -274,17 +295,9 @@ def compute_clipped_gradients_and_outputs(
     # is not defined in the contract so may not hold, especially for
     # custom losses.
     y_pred = input_model(x_batch, training=True)
-    loss_y_batch = (
-        y_batch
-        if num_microbatches is None
-        else lr.add_microbatch_axis(y_batch, num_microbatches)
-    )
-    loss_y_pred = (
-        y_pred
-        if num_microbatches is None
-        else lr.add_microbatch_axis(y_pred, num_microbatches)
-    )
-    clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights)
+    mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches)
+    mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches)
+    clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights)
   clipped_grads = tape.gradient(
       clipping_loss_value,
       input_model.trainable_variables,
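For context on why weighting the loss clips the gradient: with the weights treated as constants, the gradient of `sum_i w_i * loss_i` is `sum_i w_i * g_i`, so choosing `w_i = C / max(C, ||g_i||)` caps each per-example contribution at `C`. A toy sanity check of that identity, assuming `compute_clip_weights` uses this standard form (its body is not part of this diff):

```python
import tensorflow as tf

C = 1.0
w = tf.Variable([1.0, -2.0])
x = tf.constant([[1.0, 0.0], [0.0, 3.0]])  # two examples
y = tf.constant([0.0, 0.0])

def per_example_losses():
  # 0.5 * squared error of a linear model, one loss per example.
  return 0.5 * tf.square(tf.linalg.matvec(x, w) - y)

# Per-example gradients and their norms.
with tf.GradientTape() as tape:
  losses = per_example_losses()
per_ex_grads = tape.jacobian(losses, w)  # shape [2, 2]
norms = tf.norm(per_ex_grads, axis=1)
clip_w = C / tf.maximum(C, norms)        # assumed compute_clip_weights form

# Gradient of the reweighted (summed) loss equals the sum of the clipped
# per-example gradients, since clip_w was computed outside this tape.
with tf.GradientTape() as tape:
  weighted_loss = tf.reduce_sum(clip_w * per_example_losses())
clipped_grad = tape.gradient(weighted_loss, w)

print(clipped_grad.numpy())                                                 # [ 1. -1.]
print(tf.reduce_sum(clip_w[:, tf.newaxis] * per_ex_grads, axis=0).numpy())  # [ 1. -1.]
```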