Allow passing in a tensor to tfa.optimizers.MovingAverage num_updates #754

@Hyperparticle

Currently, tfa.optimizers.MovingAverage contains an assert that requires num_updates to be a Python int, shown here. This prevents me from passing in an integer tensor that changes with the global step, a use case that tf.train.ExponentialMovingAverage officially supports.

Can this assert be updated to handle this use case?
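
For context, tf.train.ExponentialMovingAverage documents that when num_updates is supplied, the decay actually applied is min(decay, (1 + num_updates) / (10 + num_updates)). The pure-Python sketch below only illustrates that formula. The helper name effective_decay is hypothetical and is not part of TensorFlow or TensorFlow Addons.

```python
def effective_decay(decay, num_updates=None):
    """Decay actually applied, following the formula documented for
    tf.train.ExponentialMovingAverage.

    Hypothetical helper for illustration only; not a TF or TFA API.
    """
    if num_updates is None:
        # Without num_updates the configured decay is used unchanged.
        return decay
    n = float(num_updates)
    # Early in training the ratio is small, so the average adapts quickly;
    # it approaches 1 as n grows, so the configured decay takes over.
    return min(decay, (1.0 + n) / (10.0 + n))


if __name__ == "__main__":
    # Show how the applied decay ramps up as the step count grows.
    for step in (0, 10, 100, 10000):
        print(step, effective_decay(0.999, step))
```

If num_updates must be a Python int fixed when the optimizer is constructed, this ratio is computed once and never changes. A tensor tied to the global step would presumably let the applied decay ramp up during training, which is the behavior requested here.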
