@@ -621,7 +621,7 @@ <h1>Source code for torch.distributed.checkpoint.state_dict</h1><div class="high
621621 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> skip_ddp_prefix</ span > < span class ="p "> :</ span >
622622 < span class ="n "> fqn_obj_names</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> curr_obj_name</ span > < span class ="p "> )</ span >
623623 < span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> curr_obj</ span > < span class ="p "> ,</ span > < span class ="n "> FSDP</ span > < span class ="p "> ):</ span >
624- < span class ="k "> if</ span > < span class ="n "> obj_names</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span > < span class ="o "> ==</ span > < span class ="n "> FLAT_PARAM</ span > < span class ="p "> :</ span >
624+ < span class ="k "> if</ span > < span class ="n "> i </ span > < span class =" o " > < </ span > < span class =" nb " > len </ span > < span class =" p " > ( </ span > < span class =" n " > obj_names </ span > < span class =" p " > ) </ span > < span class =" o " > - </ span > < span class =" mi " > 1 </ span > < span class =" ow " > and </ span > < span class =" n " > obj_names</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span > < span class ="o "> ==</ span > < span class ="n "> FLAT_PARAM</ span > < span class ="p "> :</ span >
625625 < span class ="n "> prefix</ span > < span class ="o "> =</ span > < span class ="s2 "> "."</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> fqn_obj_names</ span > < span class ="p "> )</ span >
626626 < span class ="n "> flat_param</ span > < span class ="o "> =</ span > < span class ="nb "> getattr</ span > < span class ="p "> (</ span > < span class ="n "> curr_obj</ span > < span class ="p "> ,</ span > < span class ="n "> FLAT_PARAM</ span > < span class ="p "> )</ span >
627627 < span class ="k "> if</ span > < span class ="n "> prefix</ span > < span class ="p "> :</ span >
@@ -660,7 +660,7 @@ <h1>Source code for torch.distributed.checkpoint.state_dict</h1><div class="high
660660 < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ],</ span > < span class ="n "> Union</ span > < span class ="p "> [</ span > < span class ="n "> Set</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ],</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> Tensor</ span > < span class ="p "> ]</ span >
661661 < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> {}</ span >
662662 < span class ="n "> all_fqns</ span > < span class ="o "> =</ span > < span class ="nb "> set</ span > < span class ="p "> ()</ span >
663- < span class ="k "> for</ span > < span class ="n "> name</ span > < span class ="p "> ,</ span > < span class ="n "> param</ span > < span class ="ow "> in</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> named_parameters</ span > < span class ="p "> ():</ span >
663+ < span class ="k "> for</ span > < span class ="n "> name</ span > < span class ="p "> ,</ span > < span class ="n "> param</ span > < span class ="ow "> in</ span > < span class ="n "> chain </ span > < span class =" p " > ( </ span > < span class =" n " > model</ span > < span class ="o "> .</ span > < span class ="n "> named_parameters</ span > < span class ="p "> (), </ span > < span class =" n " > model </ span > < span class =" o " > . </ span > < span class =" n " > named_buffers </ span > < span class =" p " > () ):</ span >
664664 < span class ="n "> fqns</ span > < span class ="o "> =</ span > < span class ="n "> _get_fqns</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> name</ span > < span class ="p "> )</ span >
665665 < span class ="n "> fqn_param_mapping</ span > < span class ="p "> [</ span > < span class ="n "> param</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> fqns</ span >
666666 < span class ="k "> for</ span > < span class ="n "> fqn</ span > < span class ="ow "> in</ span > < span class ="n "> fqns</ span > < span class ="p "> :</ span >
@@ -859,7 +859,7 @@ <h1>Source code for torch.distributed.checkpoint.state_dict</h1><div class="high
859859 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> info</ span > < span class ="o "> .</ span > < span class ="n "> handle_model</ span > < span class ="ow "> or</ span > < span class ="ow "> not</ span > < span class ="n "> state_dict</ span > < span class ="p "> :</ span >
860860 < span class ="k "> return</ span > < span class ="n "> _IncompatibleKeys</ span > < span class ="p "> ({},</ span > < span class ="p "> {})</ span >
861861
862- < span class ="k "> for</ span > < span class ="n "> key</ span > < span class ="p "> ,</ span > < span class ="n "> _</ span > < span class ="ow "> in</ span > < span class ="n "> model</ span > < span class ="o "> .</ span > < span class ="n "> named_parameters</ span > < span class ="p "> ():</ span >
862+ < span class ="k "> for</ span > < span class ="n "> key</ span > < span class ="p "> ,</ span > < span class ="n "> _</ span > < span class ="ow "> in</ span > < span class ="n "> chain </ span > < span class =" p " > ( </ span > < span class =" n " > model</ span > < span class ="o "> .</ span > < span class ="n "> named_parameters</ span > < span class ="p "> (), </ span > < span class =" n " > model </ span > < span class =" o " > . </ span > < span class =" n " > named_buffers </ span > < span class =" p " > () ):</ span >
863863 < span class ="n "> fqns</ span > < span class ="o "> =</ span > < span class ="n "> _get_fqns</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> key</ span > < span class ="p "> )</ span >
864864 < span class ="n "> fqns_with_ddp_prefix</ span > < span class ="o "> =</ span > < span class ="n "> _get_fqns</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> ,</ span > < span class ="n "> key</ span > < span class ="p "> ,</ span > < span class ="n "> skip_ddp_prefix</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span >
865865 < span class ="k "> for</ span > < span class ="n "> fqn</ span > < span class ="p "> ,</ span > < span class ="n "> fqn_with_ddp_prefix</ span > < span class ="ow "> in</ span > < span class ="nb "> zip</ span > < span class ="p "> (</ span > < span class ="n "> fqns</ span > < span class ="p "> ,</ span > < span class ="n "> fqns_with_ddp_prefix</ span > < span class ="p "> ):</ span >
@@ -1142,25 +1142,25 @@ <h1>Source code for torch.distributed.checkpoint.state_dict</h1><div class="high
11421142< span class ="sd "> optimizer parameter IDs to the canonical FQNs.</ span >
11431143
11441144< span class ="sd "> Example:</ span >
1145+ < span class ="sd "> >>> # xdoctest: +SKIP</ span >
1146+ < span class ="sd "> >>> import torch</ span >
1147+ < span class ="sd "> >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP</ span >
1148+ < span class ="sd "> >>> from torch.nn.parallel import DistributedDataParallel as DDP</ span >
1149+ < span class ="sd "> >>> from torch.distributed.checkpoint.state_dict import get_state_dict</ span >
11451150
1146- < span class ="sd "> import torch</ span >
1147- < span class ="sd "> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP</ span >
1148- < span class ="sd "> from torch.nn.parallel import DistributedDataParallel as DDP</ span >
1149- < span class ="sd "> from torch.distributed.checkpoint.state_dict import get_state_dict</ span >
1150-
1151- < span class ="sd "> fsdp_model = FSDP(copy.deepcopy(model))</ span >
1152- < span class ="sd "> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)</ span >
1153- < span class ="sd "> ddp_model = DDP(copy.deepcopy(model))</ span >
1154- < span class ="sd "> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)</ span >
1151+ < span class ="sd "> >>> fsdp_model = FSDP(copy.deepcopy(model))</ span >
1152+ < span class ="sd "> >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)</ span >
1153+ < span class ="sd "> >>> ddp_model = DDP(copy.deepcopy(model))</ span >
1154+ < span class ="sd "> >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)</ span >
11551155
11561156
1157- < span class ="sd "> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)</ span >
1158- < span class ="sd "> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)</ span >
1157+ < span class ="sd "> >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)</ span >
1158+ < span class ="sd "> >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)</ span >
11591159
1160- < span class ="sd "> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),</ span >
1161- < span class ="sd "> # the asserts will fail.</ span >
1162- < span class ="sd "> assert ddp_state_dict == fsdp_state_dict</ span >
1163- < span class ="sd "> assert ddp_optim_state == fsdp_optim_state_dict</ span >
1160+ < span class ="sd "> >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),</ span >
1161+ < span class ="sd "> >>> # the asserts will fail.</ span >
1162+ < span class ="sd "> >>> assert ddp_state_dict == fsdp_state_dict</ span >
1163+ < span class ="sd "> >>> assert ddp_optim_state_dict == fsdp_optim_state_dict</ span >
11641164
11651165
11661166< span class ="sd "> Args:</ span >
@@ -1175,6 +1175,8 @@ <h1>Source code for torch.distributed.checkpoint.state_dict</h1><div class="high
11751175
11761176< span class ="sd "> Returns:</ span >
11771177< span class ="sd "> ``Tuple`` that contain model state_dict and optimizer state_dict.</ span >
1178+
1179+ < span class ="sd "> :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]</ span >
11781180< span class ="sd "> """</ span >
11791181
11801182 < span class ="k "> with</ span > < span class ="n "> gc_context</ span > < span class ="p "> ():</ span >
0 commit comments