diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index 0eada7583..a39353b16 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -69,7 +69,7 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n d.discard(name) params = self.__dict__.get('_parameters') - if isinstance(value, (ColoTensor, torch.nn.Parameter)): + if isinstance(value, (ColoParameter, torch.nn.Parameter)): if params is None: raise AttributeError("cannot assign parameters before Module.__init__() call") remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)