Skip to content

Commit 8135613

Browse files
Update pytorch_lightning/utilities/params_tying.py
Co-authored-by: ananthsub <[email protected]>
1 parent 381e526 commit 8135613

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytorch_lightning/utilities/params_tying.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,22 @@ def find_shared_parameters(
4848
return [x for x in tied_parameters.values() if len(x) > 1]
4949

5050

51-
def set_shared_parameters(module: "pl.LightningModule", shared_params: list):
51+
def set_shared_parameters(module: nn.Module, shared_params: list):
5252
for shared_param in shared_params:
5353
ref = _get_module_by_path(module, shared_param[0])
5454
for path in shared_param[1:]:
5555
_set_module_by_path(module, path, ref)
5656
return module
5757

5858

59-
def _get_module_by_path(module: "pl.LightningModule", path: str):
59+
def _get_module_by_path(module: nn.Module, path: str):
6060
path = path.split(".")
6161
for name in path:
6262
module = getattr(module, name)
6363
return module
6464

6565

66-
def _set_module_by_path(module: "pl.LightningModule", path: str, value: Parameter):
66+
def _set_module_by_path(module: nn.Module, path: str, value: Parameter):
6767
path = path.split(".")
6868
for name in path[:-1]:
6969
module = getattr(module, name)

0 commit comments

Comments
 (0)