The last line of the following snippet crashes:
import pytorch_lightning as pl


class Module(pl.LightningModule):
    def forward(self):
        return 0


def test_outside():
    a = Module()
    print(a.module_arguments)


class A:
    def test(self):
        a = Module()
        print(a.module_arguments)

    def test2(self):
        test_outside()


test_outside()  # prints {}
A().test2()  # prints {}
A().test()  # crashes
For context, this happens when we instantiate LightningModules inside methods of a unit test class.