mirror of https://github.com/hpcaitech/ColossalAI
[gemini] fix colo_init_context (#2683)
parent
5cd8cae0c9
commit
f0aa191f51
|
@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
|
||||||
default_pg: Optional[ProcessGroup] = None,
|
default_pg: Optional[ProcessGroup] = None,
|
||||||
default_dist_spec: Optional[Any] = None) -> ColoParameter:
|
default_dist_spec: Optional[Any] = None) -> ColoParameter:
|
||||||
|
|
||||||
if isinstance(param, ColoParameter):
|
if type(param) is ColoParameter:
|
||||||
return param
|
return param
|
||||||
# detaching tensor is necessary for optimizers.
|
# detaching tensor is necessary for optimizers.
|
||||||
requires_grad = param.requires_grad
|
requires_grad = param.requires_grad
|
||||||
|
@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
"""
|
"""
|
||||||
name_list = []
|
name_list = []
|
||||||
for name, param in _named_params_with_replica(module):
|
for name, param in _named_params_with_replica(module):
|
||||||
if isinstance(param, ColoTensor):
|
if type(param) is ColoParameter:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
split = name.rfind('.')
|
split = name.rfind('.')
|
||||||
|
|
Loading…
Reference in New Issue