[Gemini] Update coloinit_ctx to support meta_tensor (#2147)

pull/2150/head^2
BlueRum 2022-12-19 22:37:07 +08:00 committed by GitHub
parent 6ad866b684
commit b3f73ce1c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 39 additions and 12 deletions

View File

@ -36,13 +36,13 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
return param
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
if param.device.type == 'meta':
raise NotImplementedError(
"ColoInitContext is initializing a model with meta parameters! This is not allowed right now!")
else:
# param is the global tensor.
# param is the global tensor.
if param.device.type == "meta":
colo_param = ColoParameter(param, requires_grad=requires_grad)
else:
colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
# if default_shard_plan exists, shard the param during initialization.
# This can reduce the model size after initialization.
@ -129,9 +129,32 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
colo_param.shared_param_modules.append(submodule)
module.to(self._device)
meta_param_flag = 0
meta_buffer_flag = 0
for param in module.parameters():
if param.device.type=="meta":
meta_param_flag = 1
if meta_param_flag == 1 and param.device.type!="meta":
raise ValueError("Meta parameters and valued parameters can not be in the same model")
for buffer in module.buffers():
if buffer.device.type=="meta":
meta_buffer_flag = 1
if meta_buffer_flag == 1 and buffer.device.type!="meta":
raise ValueError("Meta buffers and valued buffers can not be in the same model")
if meta_param_flag==1 and meta_buffer_flag==1:
pass
elif meta_buffer_flag==0 and meta_param_flag==1:
for name, buf in module.named_buffers():
module._buffers[name] = module._buffers[name].to(device=self._device)
elif meta_param_flag==0 and meta_buffer_flag==1:
for name, param in module.named_parameters():
module._parameters[name] = module._parameters[name].to(device=self._device)
else:
module.to(self._device)
def post_process_colo_init_ctx(model: torch.nn.Module,
device: torch.device = torch.device('cpu'),
@ -156,12 +179,16 @@ def post_process_colo_init_ctx(model: torch.nn.Module,
torch_params = []
for n, p in model.named_parameters():
if not isinstance(p, ColoParameter):
print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
# print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
torch_params.append((n, p))
for (n, param) in torch_params:
delattr(model, n)
setattr(model, n, _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
name_list = n.split('.')
module = model
for i in range(len(name_list) - 1):
module = module._modules[name_list[i]]
delattr(module, name_list[-1])
setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
del torch_params
for n, p in model.named_parameters():