From b3f73ce1c89f3ea9c1d3b4a432f1bc12302a49a4 Mon Sep 17 00:00:00 2001
From: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Date: Mon, 19 Dec 2022 22:37:07 +0800
Subject: [PATCH] [Gemini] Update coloinit_ctx to support meta_tensor (#2147)

---
 colossalai/utils/model/colo_init_context.py | 51 ++++++++++++++++-----
 1 file changed, 39 insertions(+), 12 deletions(-)

diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 6cb885321..93c91e099 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -36,13 +36,13 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
         return param
     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad
-
-    if param.device.type == 'meta':
-        raise NotImplementedError(
-            "ColoInitContext is initializing a model with meta parameters! This is not allowed right now!")
-    else:
-        # param is the global tensor.
+    # param is the global tensor.
+
+    if param.device.type == "meta":
+        colo_param = ColoParameter(param, requires_grad=requires_grad)
+    else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
+
     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
 
@@ -129,9 +129,32 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             delattr(submodule, param_name)
             setattr(submodule, param_name, colo_param)
             colo_param.shared_param_modules.append(submodule)
-
-        module.to(self._device)
-
+
+        meta_param_flag = 0
+        meta_buffer_flag = 0
+        for param in module.parameters():
+            if param.device.type=="meta":
+                meta_param_flag = 1
+            if meta_param_flag == 1 and param.device.type!="meta":
+                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+
+        for buffer in module.buffers():
+            if buffer.device.type=="meta":
+                meta_buffer_flag = 1
+            if meta_buffer_flag == 1 and buffer.device.type!="meta":
+                raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_param_flag==1 and meta_buffer_flag==1:
+            pass
+        elif meta_buffer_flag==0 and meta_param_flag==1:
+            for name, buf in module.named_buffers():
+                module._buffers[name] = module._buffers[name].to(device=self._device)
+        elif meta_param_flag==0 and meta_buffer_flag==1:
+            for name, param in module.named_parameters():
+                module._parameters[name] = module._parameters[name].to(device=self._device)
+        else:
+            module.to(self._device)
+
 
 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
@@ -156,12 +179,16 @@ def post_process_colo_init_ctx(model: torch.nn.Module,
     torch_params = []
     for n, p in model.named_parameters():
         if not isinstance(p, ColoParameter):
-            print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
+            # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
            torch_params.append((n, p))
 
     for (n, param) in torch_params:
-        delattr(model, n)
-        setattr(model, n, _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
+        name_list = n.split('.')
+        module = model
+        for i in range(len(name_list) - 1):
+            module = module._modules[name_list[i]]
+        delattr(module, name_list[-1])
+        setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
 
     del torch_params
     for n, p in model.named_parameters():
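
A minimal usage sketch of what this patch enables is given below. It is illustrative and not part of the diff: it assumes the ColoInitContext(device=...) constructor defined in this same file and PyTorch's device='meta' factory argument for nn.Linear; the model and sizes are made up for the example.

import torch
import torch.nn as nn

from colossalai.utils.model.colo_init_context import ColoInitContext

# Build a model whose parameters live on the 'meta' device. Before this patch,
# _convert_to_coloparam raised NotImplementedError for meta parameters; with it,
# they are wrapped as ColoParameter without being materialized. Mixing meta and
# real-valued parameters (or buffers) in one model now raises ValueError instead.
with ColoInitContext(device=torch.device('cpu')):
    model = nn.Sequential(
        nn.Linear(1024, 1024, device='meta'),
        nn.Linear(1024, 10, device='meta'),
    )

# All parameters should stay on the meta device; only real-valued buffers
# (this model has none) would be moved to the context's device.
for name, param in model.named_parameters():
    print(name, param.device)   # expected: meta

For parameters that are still plain torch.nn.Parameter after initialization, post_process_colo_init_ctx(model, device=...) converts them in place; the last hunk makes this work for parameters on nested submodules by walking the dotted name (for example 'encoder.layer.0.weight') down to the owning module before the delattr/setattr, instead of calling delattr on the top-level model.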