mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] Update coloinit_ctx to support meta_tensor (#2147)
parent 6ad866b684
commit b3f73ce1c8
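For context, a hedged usage sketch (not part of the commit): per the commit message, this change lets ColoInitContext initialize a model whose parameters live on PyTorch's meta device. The import path and the `device` constructor argument below are assumptions inferred from the diff (which references `self._device`), not verified against this revision.

    # Hypothetical usage sketch; import path and constructor signature are assumptions.
    import torch
    from colossalai.utils.model.colo_init_context import ColoInitContext

    with ColoInitContext(device=torch.device('meta')):
        # Parameters created inside the context become ColoParameters; on the
        # meta device they carry only shape/dtype metadata and no storage.
        model = torch.nn.Linear(1024, 1024)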
@@ -36,13 +36,13 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
         return param

     # detaching tensor is necessary for optimizers.
     requires_grad = param.requires_grad

-    if param.device.type == 'meta':
-        raise NotImplementedError(
-            "ColoInitContext is initializing a model with meta parameters! This is not allowed right now!")
-    else:
-        # param is the global tensor.
+    # param is the global tensor.
+
+    if param.device.type == "meta":
+        colo_param = ColoParameter(param, requires_grad=requires_grad)
+    else:
         colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)

     # if default_shard_plan exists, shard the param during initialization.
     # This can reduce the model size after initialization.
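A minimal plain-PyTorch sketch (not part of the commit) of why the new branch wraps a meta parameter as-is instead of raising: meta tensors record shape and dtype but own no storage, so copying them to a real device with `.to(device=..., dtype=...)` cannot work.

    # Plain-PyTorch sketch: meta tensors are storageless placeholders.
    import torch

    t = torch.empty(1024, 1024, device='meta')
    print(t.device.type)   # 'meta'
    print(t.shape)         # torch.Size([1024, 1024])
    # There is no data behind `t`, so materializing it by copy fails:
    # t.to('cpu')  # raises: cannot copy out of meta tensor; no data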
@@ -129,9 +129,32 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
                 delattr(submodule, param_name)
                 setattr(submodule, param_name, colo_param)
                 colo_param.shared_param_modules.append(submodule)

-        module.to(self._device)
+        meta_param_flag = 0
+        meta_buffer_flag = 0
+        for param in module.parameters():
+            if param.device.type=="meta":
+                meta_param_flag = 1
+            if meta_param_flag == 1 and param.device.type!="meta":
+                raise ValueError("Meta parameters and valued parameters can not be in the same model")
+
+        for buffer in module.buffers():
+            if buffer.device.type=="meta":
+                meta_buffer_flag = 1
+            if meta_buffer_flag == 1 and buffer.device.type!="meta":
+                raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+        if meta_param_flag==1 and meta_buffer_flag==1:
+            pass
+        elif meta_buffer_flag==0 and meta_param_flag==1:
+            for name, buf in module.named_buffers():
+                module._buffers[name] = module._buffers[name].to(device=self._device)
+        elif meta_param_flag==0 and meta_buffer_flag==1:
+            for name, param in module.named_parameters():
+                module._parameters[name] = module._parameters[name].to(device=self._device)
+        else:
+            module.to(self._device)


 def post_process_colo_init_ctx(model: torch.nn.Module,
                                device: torch.device = torch.device('cpu'),
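The added lines enforce an all-or-nothing rule: a model's parameters (and, separately, its buffers) must be either all on meta or all materialized, and only the fully materialized side is moved to `self._device`. Note the committed scan is order-sensitive: it raises only when a materialized tensor is seen *after* a meta one. A hypothetical, more compact restatement (not in the commit) using a set over device kinds catches a mix in either order:

    # Hypothetical restatement of the mixed-device check (not in the commit).
    from typing import Iterable
    import torch

    def all_on_meta(tensors: Iterable[torch.Tensor], kind: str) -> bool:
        """True if every tensor is on 'meta', False if none are; raise on a mix."""
        kinds = {t.device.type == "meta" for t in tensors}
        if kinds == {True, False}:
            raise ValueError(f"Meta {kind} and valued {kind} can not be in the same model")
        return kinds == {True}

    # Mirrors _post_init_method's two scans:
    # params_on_meta = all_on_meta(module.parameters(), "parameters")
    # buffers_on_meta = all_on_meta(module.buffers(), "buffers")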
@@ -156,12 +179,16 @@ def post_process_colo_init_ctx(model: torch.nn.Module,
     torch_params = []
     for n, p in model.named_parameters():
         if not isinstance(p, ColoParameter):
-            print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
+            # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter")
             torch_params.append((n, p))

     for (n, param) in torch_params:
-        delattr(model, n)
-        setattr(model, n, _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
+        name_list = n.split('.')
+        module = model
+        for i in range(len(name_list) - 1):
+            module = module._modules[name_list[i]]
+        delattr(module, name_list[-1])
+        setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))

     del torch_params
     for n, p in model.named_parameters():
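This rewrite fixes a real bug: `named_parameters()` yields dotted names such as `encoder.0.weight`, which `delattr(model, n)` cannot resolve, so the old code only worked for top-level parameters. The new loop walks the private `_modules` dict to find the owning submodule. A sketch (assuming torch >= 1.9, not part of the commit) of the same walk via the public `get_submodule` API:

    # Hypothetical helper (not in the commit): resolve a dotted parameter name
    # to its owning submodule with the public API, then swap the attribute.
    import torch

    def replace_attr(model: torch.nn.Module, qualified_name: str, new_value) -> None:
        *parent, leaf = qualified_name.split('.')
        owner = model.get_submodule('.'.join(parent))  # '' returns model itself
        delattr(owner, leaf)
        setattr(owner, leaf, new_value)

    # e.g. replace_attr(model, 'encoder.0.weight', colo_param)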