
[gemini] fix colo_init_context (#2683)

pull/2686/head
ver217 committed 2 years ago via GitHub
commit f0aa191f51
1 changed file: colossalai/utils/model/colo_init_context.py (2 additions, 2 deletions)
@@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
                           default_pg: Optional[ProcessGroup] = None,
                           default_dist_spec: Optional[Any] = None) -> ColoParameter:
-    if isinstance(param, ColoParameter):
+    if type(param) is ColoParameter:
         return param
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
@@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         name_list = []
         for name, param in _named_params_with_replica(module):
-            if isinstance(param, ColoTensor):
+            if type(param) is ColoParameter:
                 continue
             split = name.rfind('.')
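Note: the commit narrows the checks from isinstance() to exact type checks, so only objects that are exactly ColoParameter get skipped, while subclasses still go through conversion. A minimal, self-contained sketch of that difference in plain Python (the LazyColoParameter subclass below is hypothetical and not part of ColossalAI):

    # Illustrates isinstance() vs. an exact type() check for subclasses.
    class ColoParameter:                      # stand-in for colossalai's ColoParameter
        pass

    class LazyColoParameter(ColoParameter):   # hypothetical subclass
        pass

    p = LazyColoParameter()
    print(isinstance(p, ColoParameter))   # True  -> the old check would skip this param
    print(type(p) is ColoParameter)       # False -> the new check still converts it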
