diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index ab354ea70..87ae413a2 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -32,7 +32,7 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
                           default_pg: Optional[ProcessGroup] = None,
                           default_dist_spec: Optional[Any] = None) -> ColoParameter:
-    if isinstance(param, ColoParameter):
+    if type(param) is ColoParameter:
         return param
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
@@ -102,7 +102,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         name_list = []
         for name, param in _named_params_with_replica(module):
-            if isinstance(param, ColoTensor):
+            if type(param) is ColoParameter:
                 continue
             split = name.rfind('.')
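
For context, a minimal standalone sketch of the semantic difference this diff introduces: isinstance() matches a class and all of its subclasses, while an exact type(...) is check matches only the class itself. Base and Derived below are hypothetical stand-ins, not ColossalAI classes.

# Hypothetical stand-ins for ColoParameter and a subclass of it.
class Base:
    pass

class Derived(Base):
    pass

obj = Derived()

# isinstance() matches the class and every subclass.
print(isinstance(obj, Base))    # True

# An exact type check matches only the class itself, not subclasses.
print(type(obj) is Base)        # False
print(type(obj) is Derived)     # True

In other words, after this change only objects whose type is exactly ColoParameter take the early-return path in _convert_to_coloparam or are skipped in the ColoInitContext loop; subclasses, which the old isinstance checks would have matched, no longer do.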