@@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
                           default_pg: Optional[ProcessGroup] = None,
                           default_dist_spec: Optional[Any] = None) -> ColoParameter:
 
-    if isinstance(param, ColoParameter):
+    if type(param) is ColoParameter:
         return param
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
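
For context on the change above (not part of the diff): `isinstance` also matches subclasses, while `type(x) is Cls` matches only the exact class, so the tightened check returns early only for objects that are exactly `ColoParameter`. A minimal sketch with toy stand-in classes, where `SubParameter` is a hypothetical subclass and not a real ColossalAI type:

    # Toy stand-ins for the real ColossalAI types, for illustration only.
    class ColoTensor:
        pass

    class ColoParameter(ColoTensor):
        pass

    class SubParameter(ColoParameter):  # hypothetical subclass
        pass

    p = SubParameter()
    print(isinstance(p, ColoParameter))  # True  -- old check: subclasses pass
    print(type(p) is ColoParameter)      # False -- new check: exact class only

Under the old check a subclass instance would have been returned as-is; with the exact-type check it falls through to the conversion logic below.
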
@@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         name_list = []
         for name, param in _named_params_with_replica(module):
-            if isinstance(param, ColoTensor):
+            if type(param) is ColoParameter:
                 continue
 
             split = name.rfind('.')
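
For context (not part of the diff): the second hunk applies the same exact-type tightening inside the parameter-conversion loop, and also narrows the skipped type from `ColoTensor` to `ColoParameter`, so only parameters that are already exactly `ColoParameter` are left unconverted. The final context line splits the dotted parameter name into the owning module's path and the attribute name; a minimal sketch of that idiom with a made-up name:

    # Sketch of the `name.rfind('.')` idiom from the hunk's context.
    name = "encoder.layers.0.weight"  # hypothetical parameter name
    split = name.rfind('.')
    module_path = name[:split]        # "encoder.layers.0"
    param_name = name[split + 1:]     # "weight"
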