[Tensor] add from_pretrained support and bert pretrained test (#921)

* add from_pretrained support and test

* polish

* polish

* polish

* polish
Ziyue Jiang 2022-05-09 16:11:47 +08:00 committed by GitHub
parent 1d625fcd36
commit c195d2814c
4 changed files with 158 additions and 20 deletions
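
In short: Hugging Face's `from_pretrained` can now run inside `ColoInitContext`, so checkpoint weights are loaded into `ColoParameter`s instead of plain `nn.Parameter`s, and a BERT test compares the result against a vanilla load. A hedged usage sketch, mirroring the new `_test_pretrained` test below (the `ColoInitContext` import path is assumed; downloading 'bert-base-uncased' needs network access):

    from transformers import BertForMaskedLM
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext  # import path assumed

    # Parameters created inside the context become ColoParameters; from_pretrained's
    # checkpoint loading then flows through the new ColoTensor.data setter below.
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')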

View File

@@ -1,7 +1,7 @@
 from .op_wrapper import _COLOSSAL_OPS
 import torch
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Union
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
@@ -55,6 +55,15 @@ class ColoTensor(object):
     def data(self):
         return self._torch_tensor.data

+    @data.setter
+    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
+        if isinstance(tensor, ColoTensor):
+            self._torch_tensor.data = tensor.data
+        elif isinstance(tensor, torch.Tensor):
+            self._torch_tensor.data = tensor
+        else:
+            raise NotImplementedError
+
     @property
     def grad(self):
         return self._torch_tensor.grad
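
A plausible reason for the new `data` setter: checkpoint-loading code such as `from_pretrained` commonly writes weights through `param.data`, so a `ColoTensor` payload has to accept both a plain `torch.Tensor` and another `ColoTensor` there. A minimal sketch of the resulting behaviour:

    import torch
    from colossalai.tensor import ColoTensor

    t = ColoTensor.init_from_torch_tensor(torch.zeros(2, 2))

    # A plain torch.Tensor (e.g. a checkpoint weight) replaces the payload directly.
    t.data = torch.ones(2, 2)

    # Another ColoTensor is unwrapped to its own payload first.
    t.data = ColoTensor.init_from_torch_tensor(torch.full((2, 2), 3.0))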
@@ -148,14 +157,31 @@ class ColoTensor(object):
         assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        if self._shard_pattern == ShardPattern.Row:
-            dim = 0
-        elif self._shard_pattern == ShardPattern.Col:
-            dim = -1
+        dim = self._get_gather_dim()
         self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
         self._shard_pattern = ShardPattern.NA
         self._size = self._torch_tensor.size()

+    def global_torch_tensor(self) -> torch.Tensor:
+        out_tensor = self.torch_tensor()
+        if self.is_gathered():
+            return out_tensor
+
+        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
+        world_size = gpc.get_world_size(parallel_action.parallel_mode)
+        if world_size == 1:
+            return out_tensor
+
+        rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
+        tensor_list[rank] = out_tensor
+        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))
+
+        dim = self._get_gather_dim()
+        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
+        return out_tensor
+
     def is_gathered(self) -> bool:
         return self._shard_pattern == ShardPattern.NA
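
`global_torch_tensor()` rebuilds the full payload by all-gathering each rank's shard and concatenating along the dimension returned by `_get_gather_dim()` (Row shards along dim 0, Col shards along dim -1). A single-process sketch of just the concatenation geometry, no process group required:

    import torch

    full = torch.arange(12.0).reshape(4, 3)

    row_shards = list(torch.chunk(full, 2, dim=0))    # what each rank would hold under ShardPattern.Row
    col_shards = list(torch.chunk(full, 3, dim=-1))   # what each rank would hold under ShardPattern.Col

    assert torch.equal(torch.cat(row_shards, dim=0), full)
    assert torch.equal(torch.cat(col_shards, dim=-1), full)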
@@ -212,9 +238,7 @@ class ColoTensor(object):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

     def __getattr__(self, name):
-
         def replace_tensor_with_colo(func):
-
             def execute_func(*args, **kwargs):
                 # transform the ColoTensor args to torch Tensor.
                 args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
@@ -225,7 +249,9 @@ class ColoTensor(object):
             return execute_func

-        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
+        if hasattr(self._torch_tensor, name) == False:
+            raise AttributeError
+
         attr = getattr(self._torch_tensor, name)
         if isinstance(attr, Callable):
@@ -244,3 +270,12 @@ class ColoTensor(object):
             ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
             for output in outputs
         ])
+
+    def _get_gather_dim(self):
+        if self._shard_pattern == ShardPattern.Row:
+            dim = 0
+        elif self._shard_pattern == ShardPattern.Col:
+            dim = -1
+        else:
+            raise NotImplementedError
+        return dim
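
Turning the old assert into `raise AttributeError` is what makes attribute probing work: `hasattr` (and libraries such as transformers that probe model attributes during loading) only converts `AttributeError` into `False`, whereas an `AssertionError` would escape the probe. A minimal illustration of that contract:

    class Probed:
        def __getattr__(self, name):
            # Raising AttributeError lets hasattr() report False;
            # an assert here would crash the caller instead.
            raise AttributeError(name)

    print(hasattr(Probed(), "anything"))  # False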

View File

@@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter
 import types
 from torch import nn
-from typing import Iterator, Tuple, Union
+from typing import Iterator, Tuple, Union, Optional
+
+
+# Adapted from torch.nn.module.Module.register_param
+def _register_parameter_with_colotensor(self, name: str, param):
+    if '_parameters' not in self.__dict__:
+        raise AttributeError(
+            "cannot assign parameter before Module.__init__() call")
+    if not isinstance(name, torch._six.string_classes):
+        raise TypeError("parameter name should be a string. "
+                        "Got {}".format(torch.typename(name)))
+    if '.' in name:
+        raise KeyError("parameter name can't contain \".\"")
+    if name == '':
+        raise KeyError("parameter name can't be empty string \"\"")
+    if hasattr(self, name) and name not in self._parameters:
+        raise KeyError("attribute '{}' already exists".format(name))
+
+    if param is None:
+        self._parameters[name] = None
+    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
+        raise TypeError("cannot assign '{}' object to parameter '{}' "
+                        "(torch.nn.Parameter or ColoParameter or None required)"
+                        .format(torch.typename(param), name))
+    elif param.grad_fn:
+        raise ValueError(
+            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
+            "parameters must be created explicitly. To express '{0}' "
+            "as a function of another Tensor, compute the value in "
+            "the forward() method.".format(name))
+    else:
+        self._parameters[name] = param
+
+
+# Adapted from torch.nn.module.Module.__setattr__
+def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
+
+    def remove_from(*dicts_or_sets):
+        for d in dicts_or_sets:
+            if name in d:
+                if isinstance(d, dict):
+                    del d[name]
+                else:
+                    d.discard(name)
+
+    params = self.__dict__.get('_parameters')
+    if isinstance(value, (ColoTensor, torch.nn.Parameter)):
+        if params is None:
+            raise AttributeError(
+                "cannot assign parameters before Module.__init__() call")
+        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
+        self.register_parameter(name, value)
+    elif params is not None and name in params:
+        if value is not None:
+            raise TypeError("cannot assign '{}' as parameter '{}' "
+                            "(torch.nn.Parameter or None expected)"
+                            .format(torch.typename(value), name))
+        self.register_parameter(name, value)
+    else:
+        modules = self.__dict__.get('_modules')
+        if isinstance(value, torch.nn.Module):
+            if modules is None:
+                raise AttributeError(
+                    "cannot assign module before Module.__init__() call")
+            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
+            modules[name] = value
+        elif modules is not None and name in modules:
+            if value is not None:
+                raise TypeError("cannot assign '{}' as child module '{}' "
+                                "(torch.nn.Module or None expected)"
+                                .format(torch.typename(value), name))
+            modules[name] = value
+        else:
+            buffers = self.__dict__.get('_buffers')
+            if buffers is not None and name in buffers:
+                if value is not None and not isinstance(value, torch.Tensor):
+                    raise TypeError("cannot assign '{}' as buffer '{}' "
+                                    "(torch.Tensor or None expected)"
+                                    .format(torch.typename(value), name))
+                buffers[name] = value
+            else:
+                object.__setattr__(self, name, value)
+

 def ColoModulize(module):
     """
@@ -64,7 +142,6 @@ def ColoModulize(module):
     module.colo_named_parameters = funcType(colo_named_parameters, module)
     module._colo_visited = True


 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device

+        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
+        torch.nn.Module.__setattr__ = _setattr_with_colotensor
+        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
         if hasattr(module, '_colo_visited'):
             return
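
The `TODO(jzy)` in `__init__` above flags that the monkey-patch is installed globally and never undone. A hypothetical way to scope it, not part of this commit, would be to stash the originals and restore them when the context exits:

    import torch

    class _ScopedModulePatch:
        """Hypothetical sketch; assumes the two helper functions defined earlier in this file."""

        def __enter__(self):
            self._old_setattr = torch.nn.Module.__setattr__
            self._old_register = torch.nn.Module.register_parameter
            torch.nn.Module.__setattr__ = _setattr_with_colotensor
            torch.nn.Module.register_parameter = _register_parameter_with_colotensor
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Put the original methods back so the patch does not leak out of the with-block.
            torch.nn.Module.__setattr__ = self._old_setattr
            torch.nn.Module.register_parameter = self._old_register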
@@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             tensor_detached = param.to(self._device).detach()
             tensor_detached.requires_grad = requires_grad

-            setattr(module, name,
-                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
+            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
+            setattr(module, name, colo_param)

         ColoModulize(module)

View File

@@ -23,7 +23,7 @@ from transformers.file_utils import ModelOutput
 from dataclasses import fields


-def _post_init_colo(self):
+def _post_init_colotensor(self):
     class_fields = fields(self)

     # Safety and consistency checks
     if len(class_fields) == 0:
@@ -72,7 +72,7 @@ def _post_init_colo(self):
         self[field.name] = v


-ModelOutput.__post_init__ = _post_init_colo
+ModelOutput.__post_init__ = _post_init_colotensor
 # complete the hack
@@ -278,6 +278,26 @@ def test_colo_optimizer():
         if i > 5:
             break


+def _test_pretrained():
+    from _utils import check_equal
+    from transformers import BertForMaskedLM
+    set_seed(1)
+    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+    model_pretrained = model_pretrained.cuda()
+    model = model.cuda()
+
+    dict_pretrained = {}
+    dict_col = {}
+    for name, param in model_pretrained.named_parameters():
+        dict_pretrained[name] = param
+    for name, param in model.named_parameters():
+        dict_col[name] = param
+
+    for name, param in dict_pretrained.items():
+        check_equal(param, dict_col[name])
+
+
 def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
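
`check_equal` is imported from the test suite's `_utils` helper and is not part of this diff; a hypothetical stand-in for the comparison the test relies on (a `ColoParameter` payload matching the vanilla `nn.Parameter`) might look like:

    import torch

    def check_equal(a, b):
        # Unwrap ColoTensor / ColoParameter payloads before comparing element-wise.
        a = a.torch_tensor() if hasattr(a, 'torch_tensor') else a
        b = b.torch_tensor() if hasattr(b, 'torch_tensor') else b
        assert torch.allclose(a, b), "weights loaded inside ColoInitContext should match the vanilla load"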
@@ -377,4 +397,5 @@ def test_model(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model()
+    # test_model()
+    _test_pretrained()

View File

@@ -1,6 +1,6 @@
 from numpy import allclose
 import torch
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ColoParameter
 from copy import deepcopy
 from colossalai.utils import get_current_device
@@ -16,7 +16,7 @@ def test_layernorm():
     delattr(ln_op_colo, 'weight')
     weight_clone = ln_op.weight.clone().detach()
     weight_clone.requires_grad = True
-    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+    setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))

     output = ln_op(input_t)
     output_colo = ln_op_colo(input_t_colo)
@@ -39,8 +39,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
-    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
+    sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')
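
The `ColoTensor` to `ColoParameter` switch in these tests matches the patched `register_parameter` above: module parameters may now only be `torch.nn.Parameter` or `ColoParameter`, so assigning a bare `ColoTensor` would be rejected. A hedged sketch, assuming the patch from this commit is active (e.g. a `ColoInitContext` has been constructed):

    import torch
    from colossalai.tensor import ColoTensor, ColoParameter

    fc = torch.nn.Linear(4, 4)

    fc.weight = ColoParameter.init_from_torch_tensor(fc.weight.detach())   # accepted
    try:
        fc.weight = ColoTensor.init_from_torch_tensor(torch.randn(4, 4))   # rejected by register_parameter
    except TypeError as err:
        print(err)  # "... torch.nn.Parameter or ColoParameter or None required ..."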