Description
🚀 Feature
The idea is to make Metric's reduction/gathering ops configurable. By default, we use our own code, but the user can globally override those functions, for example when using a custom, unsupported distributed framework, or when dealing with asymmetry like here.
EDIT:
When a metric is implemented, methods such as `reset` and `update` are decorated with `reinit__is_reduced`, and `compute` is decorated with `sync_all_reduce`.
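To illustrate how the two decorators cooperate, here is a self-contained toy model, not ignite's code: `reinit_is_reduced`, `sync_all_reduce` and `fake_all_reduce` below are hypothetical stand-ins that only reproduce the `_is_reduced` bookkeeping. The simulated all_reduce pretends two processes hold the same value, so a SUM doubles it.

```python
from functools import wraps
from typing import Any, Callable, List

WORLD_SIZE = 2  # hypothetical: pretend two processes participate
calls: List[float] = []  # records every simulated all_reduce call


def fake_all_reduce(value: float) -> float:
    # Stand-in for idist.all_reduce: every process holds the same value,
    # so a SUM over WORLD_SIZE processes multiplies it by WORLD_SIZE.
    calls.append(value)
    return value * WORLD_SIZE


def reinit_is_reduced(func: Callable) -> Callable:
    # Mimics reinit__is_reduced: reset/update invalidate any previous reduction.
    @wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        result = func(self, *args, **kwargs)
        self._is_reduced = False
        return result

    return wrapper


def sync_all_reduce(*attrs: str) -> Callable:
    # Mimics sync_all_reduce, omitting the world-size check and ":OP" suffixes.
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            if attrs and not self._is_reduced:
                for attr in attrs:
                    setattr(self, attr, fake_all_reduce(getattr(self, attr)))
                self._is_reduced = True
            return func(self, *args, **kwargs)

        return wrapper

    return decorator


class ToySum:
    def __init__(self) -> None:
        self.reset()

    @reinit_is_reduced
    def reset(self) -> None:
        self._sum = 0.0

    @reinit_is_reduced
    def update(self, value: float) -> None:
        self._sum += value

    @sync_all_reduce("_sum")
    def compute(self) -> float:
        return self._sum


metric = ToySum()
metric.update(1.5)
metric.update(2.5)
first = metric.compute()   # reduces _sum once: 4.0 * 2 = 8.0
second = metric.compute()  # already reduced, so no second all_reduce
```

The `_is_reduced` flag is what keeps repeated `compute()` calls from reducing the same state twice; any new `reset()` or `update()` clears it.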
`sync_all_reduce` is implemented in `ignite/metrics/metric.py` of the ignite repository (lines 550 to 594 at commit 581f5b4):
```python
def sync_all_reduce(*attrs: Any) -> Callable:
    """Helper decorator for distributed configuration to collect instance attribute value
    across all participating processes and apply the specified reduction operation.

    See :doc:`metrics` on how to use it.

    Args:
        attrs: attribute names of decorated class

    .. versionchanged:: 0.4.5
        - Ability to handle different reduction operations (SUM, MAX, MIN, PRODUCT).
    """

    def wrapper(func: Callable) -> Callable:
        @wraps(func)
        def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
            if not isinstance(self, Metric):
                raise RuntimeError(
                    "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
                )
            ws = idist.get_world_size()
            if len(attrs) > 0 and not self._is_reduced:
                if ws > 1:
                    for attr in attrs:
                        op_kwargs = {}
                        if ":" in attr:
                            attr, op = attr.split(":")
                            valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
                            if op not in valid_ops:
                                raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}")
                            op_kwargs["op"] = op
                        t = getattr(self, attr, None)
                        if t is not None:
                            t = idist.all_reduce(t, **op_kwargs)
                            self._is_reduced = True
                            setattr(self, attr, t)
                else:
                    self._is_reduced = True

            return func(self, *args, **kwargs)

        return another_wrapper

    setattr(wrapper, "_decorated", True)
    return wrapper
```
where we are using `idist.all_reduce(t, **op_kwargs)`.

So, the issue description says:

> The idea is to make Metric's reduction/gathering ops configurable. By default, we use our own code, but the user can globally override those functions.
In other words, we would like `sync_all_reduce` to call a user-provided `all_reduce` function instead of `idist.all_reduce`.
A tentative API for this feature:

```python
from typing import Any, Union

import torch

import ignite.distributed as idist
# Proposed API from this issue (not yet available in ignite.metrics)
from ignite.metrics import set_all_reduce_fn, reset_all_reduce_fn, get_all_reduce_fn
from ignite.metrics import Accuracy


def my_all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM", **kwargs: Any):
    # ... custom implementation
    pass


set_all_reduce_fn(my_all_reduce)
assert get_all_reduce_fn() is my_all_reduce

acc = Accuracy()
acc.update(...)
value = acc.compute()  # should call my_all_reduce

reset_all_reduce_fn()
assert get_all_reduce_fn() is idist.all_reduce
```
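The body of `my_all_reduce` is left open above. One possible shape, sketched here for plain Python scalars only, is to gather the per-process values through the custom framework and fold them with the reduction named by `op`. The transport `gather_from_all_processes` is hypothetical and faked with a fixed three-process world; a real version would exchange values through the framework in question and, for tensors, use elementwise operations instead of `min`/`max`.

```python
import functools
import operator
from typing import Any, Callable, Dict, List, Union

Number = Union[int, float]

# Map the op names accepted by sync_all_reduce to binary Python reductions.
_OPS: Dict[str, Callable[[Number, Number], Number]] = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
}


def gather_from_all_processes(value: Number) -> List[Number]:
    # Hypothetical transport: a real implementation would exchange `value`
    # through the custom framework. Here the other two ranks hold 2 and 5.
    return [value, 2, 5]


def my_all_reduce(value: Number, op: str = "SUM", **kwargs: Any) -> Number:
    if op not in _OPS:
        raise ValueError(f"Unsupported reduction op: {op}")
    values = gather_from_all_processes(value)
    # Fold all per-process values with the requested reduction.
    return functools.reduce(_OPS[op], values)
```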