[ColoTensor] throw error when ColoInitContext meets meta parameter. (#2105)

pull/2106/head^2
Jiarui Fang 2022-12-09 11:39:46 +08:00 committed by GitHub
parent d87baa85d9
commit 05545bfee9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 7 additions and 2 deletions

View File

@ -36,8 +36,13 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
return param
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# param is the global tensor.
colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
if param.device.type == 'meta':
raise NotImplemented(
"ColoInitContext is initializing a model with meta parameters! This is not allowed right now!")
else:
# param is the global tensor.
colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
# if default_shard_plan exists, shard the param during initialization.
# This can reduce the model size after initialization.