diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 7aed1d471..5758cb8b3 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -6,6 +6,21 @@ import types
 from torch import nn
 from typing import Iterator, Tuple, Union, Optional
 
+# like Module.named_parameters(), but also yields duplicated references to shared parameters
+def _named_params_with_replica(
+    module: nn.Module,
+    prefix: str = '',
+    recurse: bool = True,
+) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
+    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+    for mod_prefix, mod in modules:
+        for name, val in mod._parameters.items():
+            if val is None:
+                continue
+            name = mod_prefix + ('.' if mod_prefix else '') + name
+            yield name, val
+
 # Adapted from torch.nn.module.Module.register_param
 def _register_parameter_with_colotensor(self, name: str, param):
     if '_parameters' not in self.__dict__:
@@ -139,21 +154,36 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             return
 
         name_list = []
-        for name, param in module.named_parameters(recurse=False):
+        for name, param in _named_params_with_replica(module):
             if isinstance(param, ColoTensor):
                 continue
-            name_list.append((name, param))
-
-        save_torch_payload = True if not self._lazy_memory_allocate else False
-        for name, param in name_list:
-            delattr(module, name)
-            # detaching tensor is necessary for optimizers.
-            requires_grad = param.requires_grad
-            tensor_detached = param.to(self._device).detach()
-            tensor_detached.requires_grad = requires_grad
-
-            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
-            setattr(module, name, colo_param)
+            split = name.rfind('.')
+            if split >= 0:    # param in submodule
+                module_name = name[:split]
+                param_name = name[split + 1:]
+            else:
+                module_name = ''    # param in current module
+                param_name = name
+            name_list.append((module_name, param_name))
+
+        replaced_tensors = dict()    # map torch.Tensor -> ColoTensor so that shared references are replaced only once
+        for module_name, param_name in name_list:
+            submodule = module.get_submodule(module_name)
+            param = submodule.get_parameter(param_name)
+            if param in replaced_tensors:
+                colo_param = replaced_tensors[param]
+            else:
+                save_torch_payload = True if not self._lazy_memory_allocate else False
+                # detaching tensor is necessary for optimizers.
+                requires_grad = param.requires_grad
+                tensor_detached = param.to(self._device).detach()
+                tensor_detached.requires_grad = requires_grad
+
+                colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
+                # record the replacement for later shared references
+                replaced_tensors[param] = colo_param
+            delattr(submodule, param_name)
+            setattr(submodule, param_name, colo_param)
 
         ColoModulize(module)
\ No newline at end of file
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 7fe850af1..2b5b120b1 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -370,16 +370,22 @@ def _run_pretrain_load():
     dict_pretrained = {}
     dict_col = {}
+    c_ref = 0
     for name, param in model_pretrained.named_parameters():
         dict_pretrained[name] = param
+        c_ref += 1
 
     c1 = 0
     c2 = 0
     for name, param in model.colo_named_parameters():
         if isinstance(param, ColoParameter):
-            c1 = c1 + 1
+            c1 += 1
         else:
-            c2 = c2 + 1
+            c2 += 1
         dict_col[name] = param
+    assert c_ref == c1
+    assert c2 == 0
+    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
+        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias
 
     for name, param in dict_pretrained.items():
         check_equal(param, dict_col[name])
@@ -423,5 +429,4 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model()
-    # _test_pretrain_load(4)
-    _run_pretrain_load()
+    _test_pretrain_load(4)
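The replaced_tensors bookkeeping in the first hunk exists so that parameters shared between modules (weight tying, e.g. the decoder bias the updated test checks) end up wrapped by a single ColoParameter instead of being split into independent copies. Below is a minimal standalone sketch of that deduplication pattern, not part of the patch: it uses plain nn.Parameter in place of ColoParameter, and TiedModel / replace_params_keeping_ties are hypothetical names introduced only for illustration.

from torch import nn


class TiedModel(nn.Module):
    """Two linear layers that share one weight Parameter (weight tying)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4, bias=False)
        self.decoder = nn.Linear(4, 4, bias=False)
        # tie the weights: both modules now hold the same Parameter object
        self.decoder.weight = self.encoder.weight


def replace_params_keeping_ties(model: nn.Module) -> None:
    """Replace every Parameter, reusing one replacement per shared tensor."""
    replaced = {}    # original Parameter -> replacement, keyed by identity
    targets = []
    # pass 1: collect (owner, attr_name, param) without mutating the model,
    # mirroring the two-pass structure of _post_init_method in the patch
    for _, mod in model.named_modules():
        for attr_name, param in list(mod._parameters.items()):
            if param is not None:
                targets.append((mod, attr_name, param))
    # pass 2: create at most one replacement per underlying tensor
    for mod, attr_name, param in targets:
        if param not in replaced:
            replaced[param] = nn.Parameter(param.detach().clone(), requires_grad=param.requires_grad)
        setattr(mod, attr_name, replaced[param])


model = TiedModel()
replace_params_keeping_ties(model)
assert model.decoder.weight is model.encoder.weight    # the tie survives

Without the mapping, encoder and decoder would each receive their own replacement and the tie would be silently broken, which is what the new asserts in test_model.py guard against.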