[Tensor] add from_pretrained support and bert pretrained test (#921)

* add from_pretrained support and test

* polish

* polish

* polish

* polish
Ziyue Jiang 2022-05-09 16:11:47 +08:00 committed by GitHub
parent 1d625fcd36
commit c195d2814c
4 changed files with 158 additions and 20 deletions


@@ -1,7 +1,7 @@
from .op_wrapper import _COLOSSAL_OPS
import torch
from typing import Tuple, Optional, Callable
from typing import Tuple, Optional, Callable, Union
from numpy import product
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import divide
@@ -55,6 +55,15 @@ class ColoTensor(object):
    def data(self):
        return self._torch_tensor.data

    @data.setter
    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
        if isinstance(tensor, ColoTensor):
            self._torch_tensor.data = tensor.data
        elif isinstance(tensor, torch.Tensor):
            self._torch_tensor.data = tensor
        else:
            raise NotImplementedError

    @property
    def grad(self):
        return self._torch_tensor.grad
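For context, a minimal usage sketch of the new data setter (not part of the diff itself, and assuming colossalai is installed with ColoTensor importable as above): assigning either a plain torch.Tensor or another ColoTensor swaps the wrapped payload in place.

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(tensor=torch.zeros(4, 4))
t.data = torch.ones(4, 4)                                             # torch.Tensor branch of the setter
t.data = ColoTensor.init_from_torch_tensor(tensor=torch.ones(4, 4))   # ColoTensor branch copies tensor.data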
@@ -148,14 +157,31 @@ class ColoTensor(object):
        assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
        assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
        if self._shard_pattern == ShardPattern.Row:
            dim = 0
        elif self._shard_pattern == ShardPattern.Col:
            dim = -1
        dim = self._get_gather_dim()
        self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
        self._shard_pattern = ShardPattern.NA
        self._size = self._torch_tensor.size()

    def global_torch_tensor(self) -> torch.Tensor:
        out_tensor = self.torch_tensor()
        if self.is_gathered():
            return out_tensor

        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
        world_size = gpc.get_world_size(parallel_action.parallel_mode)
        if world_size == 1:
            return out_tensor

        rank = gpc.get_local_rank(parallel_action.parallel_mode)
        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
        tensor_list[rank] = out_tensor
        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))

        dim = self._get_gather_dim()
        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
        return out_tensor

    def is_gathered(self) -> bool:
        return self._shard_pattern == ShardPattern.NA
@@ -212,9 +238,7 @@ class ColoTensor(object):
        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

    def __getattr__(self, name):

        def replace_tensor_with_colo(func):

            def execute_func(*args, **kwargs):
                # transform the ColoTensor args to torch Tensor.
                args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
@@ -225,7 +249,9 @@ class ColoTensor(object):
            return execute_func

        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
        if hasattr(self._torch_tensor, name) == False:
            raise AttributeError
        attr = getattr(self._torch_tensor, name)
        if isinstance(attr, Callable):
@@ -244,3 +270,12 @@ class ColoTensor(object):
            ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
            for output in outputs
        ])

    def _get_gather_dim(self):
        if self._shard_pattern == ShardPattern.Row:
            dim = 0
        elif self._shard_pattern == ShardPattern.Col:
            dim = -1
        else:
            raise NotImplementedError
        return dim
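For readers outside the codebase, the gather path above boils down to a standard all_gather-then-cat pattern, with _get_gather_dim() picking dim 0 for row shards and dim -1 for column shards. The standalone sketch below reproduces that pattern with plain torch.distributed instead of colossalai's gpc; the process-group setup and file name are assumptions for illustration, not part of this commit.

# gather_sketch.py -- launch with: torchrun --nproc_per_node=2 gather_sketch.py
import torch
import torch.distributed as dist

def gather_shard(shard: torch.Tensor, dim: int) -> torch.Tensor:
    # mirrors global_torch_tensor(): collect every rank's shard, then concatenate
    world_size = dist.get_world_size()
    if world_size == 1:
        return shard
    tensor_list = [torch.empty_like(shard) for _ in range(world_size)]
    tensor_list[dist.get_rank()] = shard
    dist.all_gather(tensor_list, shard)
    return torch.cat(tensor_list, dim=dim).contiguous()

if __name__ == '__main__':
    dist.init_process_group(backend='gloo')
    # each rank holds one row shard; dim=0 corresponds to ShardPattern.Row
    shard = torch.full((1, 4), float(dist.get_rank()))
    print(gather_shard(shard, dim=0))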


@@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter
import types
from torch import nn
from typing import Iterator, Tuple, Union
from typing import Iterator, Tuple, Union, Optional


# Adapted from torch.nn.module.Module.register_param
def _register_parameter_with_colotensor(self, name: str, param):
    if '_parameters' not in self.__dict__:
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")

    if not isinstance(name, torch._six.string_classes):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or ColoParameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param


# Adapted from torch.nn.module.Module.__setattr__
def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):

    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)

    params = self.__dict__.get('_parameters')
    if isinstance(value, (ColoTensor, torch.nn.Parameter)):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, torch.nn.Module):
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)"
                                .format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)"
                                    .format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)


def ColoModulize(module):
    """
@@ -64,7 +142,6 @@ def ColoModulize(module):
    module.colo_named_parameters = funcType(colo_named_parameters, module)
    module._colo_visited = True


class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
        self._lazy_memory_allocate = lazy_memory_allocate
        self._device = device

        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        FIXME(fjr) The module may be passed to this function multiple times?
        """
        if hasattr(module, '_colo_visited'):
            return
@@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
            tensor_detached = param.to(self._device).detach()
            tensor_detached.requires_grad = requires_grad

            setattr(module, name,
                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
            setattr(module, name, colo_param)

        ColoModulize(module)
        ColoModulize(module)
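A minimal sketch of how the context is meant to be used after this change (not part of the diff): a HuggingFace model is built with from_pretrained() inside ColoInitContext so that the patched __setattr__/register_parameter register its weights as ColoParameters. The import path for ColoInitContext is an assumption here; the test below imports it from the package under test.

import torch
from transformers import BertForMaskedLM
from colossalai.utils.model.colo_init_context import ColoInitContext  # import path assumed

with ColoInitContext(lazy_memory_allocate=False, device=torch.device('cpu')):
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# every registered weight should now be a ColoParameter wrapping the pretrained payload
for name, param in model.named_parameters():
    print(name, type(param).__name__)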


@@ -23,7 +23,7 @@ from transformers.file_utils import ModelOutput
from dataclasses import fields


def _post_init_colo(self):
def _post_init_colotensor(self):
    class_fields = fields(self)

    # Safety and consistency checks
    if len(class_fields) == 0:
@@ -72,7 +72,7 @@ def _post_init_colo(self):
            self[field.name] = v


ModelOutput.__post_init__ = _post_init_colo
ModelOutput.__post_init__ = _post_init_colotensor
# complete the hack
@@ -278,6 +278,26 @@ def test_colo_optimizer():
        if i > 5:
            break


def _test_pretrained():
    from _utils import check_equal
    from transformers import BertForMaskedLM

    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    dict_col = {}
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
    for name, param in model.named_parameters():
        dict_col[name] = param

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_1d_row_tp(model_name: str):
    # A simple net with two stacked nn.Linear
@@ -377,4 +397,5 @@ def test_model(world_size):
if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    test_model()
    # test_model()
    _test_pretrained()
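check_equal in _test_pretrained comes from the test directory's _utils module, which this commit does not touch. A plausible minimal stand-in, shown only to make the parameter comparison concrete (the real helper and its tolerances may differ), is:

import torch

def check_equal(a: torch.Tensor, b: torch.Tensor) -> None:
    # element-wise comparison with a loose tolerance; the tolerances here are assumptions
    assert torch.allclose(a, b, rtol=1e-3, atol=1e-1)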


@@ -1,6 +1,6 @@
from numpy import allclose
import torch
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, ColoParameter
from copy import deepcopy
from colossalai.utils import get_current_device
@@ -16,7 +16,7 @@ def test_layernorm():
    delattr(ln_op_colo, 'weight')
    weight_clone = ln_op.weight.clone().detach()
    weight_clone.requires_grad = True
    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
    setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))

    output = ln_op(input_t)
    output_colo = ln_op_colo(input_t_colo)
@@ -39,8 +39,8 @@ def test_linear():
    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
    sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight)
    sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias)

    # replace the torch nn.Parameters with ShardedTensor
    delattr(fc, 'weight')