🐛 Bug
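The case below fails under init_meta_context: an instance of MyModule created inside the context manager is not recognized as an instance of its parent class BaseModule, so the isinstance assertion fails. This probably has something to do with MyModule's subclass relationships being changed. As a result, this also fails when we use an xformers model factory within the context manager.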
import torch
from pytorch_lightning.utilities.meta import init_meta_context


class BaseModule(torch.nn.Module):
    pass


class MyModule(BaseModule):
    pass


with init_meta_context():
    my_module = MyModule()
    # Fails: my_module is not recognized as a BaseModule inside the context.
    assert isinstance(my_module, BaseModule)
cc @tchaton @blefaudeux