From c195d2814cd31b78f3a3af2f5d6f78b2df96be1b Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Mon, 9 May 2022 16:11:47 +0800 Subject: [PATCH] [Tensor] add from_pretrained support and bert pretrained test (#921) * add from_pretrained support and test * polish * polish * polish * polish --- colossalai/tensor/colo_tensor.py | 51 ++++++++++-- colossalai/utils/model/colo_init_context.py | 92 +++++++++++++++++++-- tests/test_tensor/test_model.py | 27 +++++- tests/test_tensor/test_op.py | 8 +- 4 files changed, 158 insertions(+), 20 deletions(-) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 7dc1e78f7..f3a542ff6 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,7 +1,7 @@ from .op_wrapper import _COLOSSAL_OPS import torch -from typing import Tuple, Optional, Callable +from typing import Tuple, Optional, Callable, Union from numpy import product from colossalai.core import global_context as gpc from colossalai.nn.layer.utils import divide @@ -55,6 +55,15 @@ class ColoTensor(object): def data(self): return self._torch_tensor.data + @data.setter + def data(self, tensor: Union[torch.Tensor, "ColoTensor"]): + if isinstance(tensor, ColoTensor): + self._torch_tensor.data = tensor.data + elif isinstance(tensor, torch.Tensor): + self._torch_tensor.data = tensor + else: + raise NotImplementedError + @property def grad(self): return self._torch_tensor.grad @@ -148,14 +157,31 @@ class ColoTensor(object): assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.' assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.' parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP) - if self._shard_pattern == ShardPattern.Row: - dim = 0 - elif self._shard_pattern == ShardPattern.Col: - dim = -1 + dim = self._get_gather_dim() self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim) self._shard_pattern = ShardPattern.NA self._size = self._torch_tensor.size() + def global_torch_tensor(self) -> torch.Tensor: + out_tensor = self.torch_tensor() + if self.is_gathered(): + return out_tensor + + parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP) + world_size = gpc.get_world_size(parallel_action.parallel_mode) + if world_size == 1: + return out_tensor + + rank = gpc.get_local_rank(parallel_action.parallel_mode) + tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)] + tensor_list[rank] = out_tensor + torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode)) + + dim = self._get_gather_dim() + out_tensor = torch.cat(tensor_list, dim=dim).contiguous() + + return out_tensor + def is_gathered(self) -> bool: return self._shard_pattern == ShardPattern.NA @@ -212,9 +238,7 @@ class ColoTensor(object): return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o) def __getattr__(self, name): - def replace_tensor_with_colo(func): - def execute_func(*args, **kwargs): # transform the ColoTensor args to torch Tensor. args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args] @@ -225,7 +249,9 @@ class ColoTensor(object): return execute_func - assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor" + if hasattr(self._torch_tensor, name) == False: + raise AttributeError + attr = getattr(self._torch_tensor, name) if isinstance(attr, Callable): @@ -244,3 +270,12 @@ class ColoTensor(object): ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output for output in outputs ]) + + def _get_gather_dim(self): + if self._shard_pattern == ShardPattern.Row: + dim = 0 + elif self._shard_pattern == ShardPattern.Col: + dim = -1 + else: + raise NotImplementedError + return dim diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index 5853e369d..877c8428c 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter import types from torch import nn -from typing import Iterator, Tuple, Union +from typing import Iterator, Tuple, Union, Optional +# Adapted from torch.nn.module.Module.register_param +def _register_parameter_with_colotensor(self, name: str, param): + if '_parameters' not in self.__dict__: + raise AttributeError( + "cannot assign parameter before Module.__init__() call") + + if not isinstance(name, torch._six.string_classes): + raise TypeError("parameter name should be a string. " + "Got {}".format(torch.typename(name))) + if '.' in name: + raise KeyError("parameter name can't contain \".\"") + if name == '': + raise KeyError("parameter name can't be empty string \"\"") + if hasattr(self, name) and name not in self._parameters: + raise KeyError("attribute '{}' already exists".format(name)) + + if param is None: + self._parameters[name] = None + elif not isinstance(param, (torch.nn.Parameter, ColoParameter)): + raise TypeError("cannot assign '{}' object to parameter '{}' " + "(torch.nn.Parameter or ColoParameter or None required)" + .format(torch.typename(param), name)) + elif param.grad_fn: + raise ValueError( + "Cannot assign non-leaf Tensor to parameter '{0}'. Model " + "parameters must be created explicitly. To express '{0}' " + "as a function of another Tensor, compute the value in " + "the forward() method.".format(name)) + else: + self._parameters[name] = param + +# Adapted from torch.nn.module.Module.__setattr__ +def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]): + def remove_from(*dicts_or_sets): + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + params = self.__dict__.get('_parameters') + if isinstance(value, (ColoTensor, torch.nn.Parameter)): + if params is None: + raise AttributeError( + "cannot assign parameters before Module.__init__() call") + remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set) + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError("cannot assign '{}' as parameter '{}' " + "(torch.nn.Parameter or None expected)" + .format(torch.typename(value), name)) + self.register_parameter(name, value) + else: + modules = self.__dict__.get('_modules') + if isinstance(value, torch.nn.Module): + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call") + remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set) + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError("cannot assign '{}' as child module '{}' " + "(torch.nn.Module or None expected)" + .format(torch.typename(value), name)) + modules[name] = value + else: + buffers = self.__dict__.get('_buffers') + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError("cannot assign '{}' as buffer '{}' " + "(torch.Tensor or None expected)" + .format(torch.typename(value), name)) + buffers[name] = value + else: + object.__setattr__(self, name, value) def ColoModulize(module): """ @@ -64,7 +142,6 @@ def ColoModulize(module): module.colo_named_parameters = funcType(colo_named_parameters, module) module._colo_visited = True - class ColoInitContext(InsertPostInitMethodToModuleSubClasses): def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')): @@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): self._lazy_memory_allocate = lazy_memory_allocate self._device = device + # TODO(jzy) replace it with old __setattr__ in the exit() of context? + torch.nn.Module.__setattr__ = _setattr_with_colotensor + torch.nn.Module.register_parameter = _register_parameter_with_colotensor + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ The function to call at the end of the constructor of each module. FIXME(fjr) The module may be passed to this function multiple times? """ + if hasattr(module, '_colo_visited'): return @@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): tensor_detached = param.to(self._device).detach() tensor_detached.requires_grad = requires_grad - setattr(module, name, - ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)) + colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload) + setattr(module, name, colo_param) - ColoModulize(module) + ColoModulize(module) \ No newline at end of file diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index aabf4c7f6..c75242100 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -23,7 +23,7 @@ from transformers.file_utils import ModelOutput from dataclasses import fields -def _post_init_colo(self): +def _post_init_colotensor(self): class_fields = fields(self) # Safety and consistency checks if len(class_fields) == 0: @@ -72,7 +72,7 @@ def _post_init_colo(self): self[field.name] = v -ModelOutput.__post_init__ = _post_init_colo +ModelOutput.__post_init__ = _post_init_colotensor # complete the hack @@ -278,6 +278,26 @@ def test_colo_optimizer(): if i > 5: break +def _test_pretrained(): + from _utils import check_equal + from transformers import BertForMaskedLM + set_seed(1) + model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased') + with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()): + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + + model_pretrained = model_pretrained.cuda() + model = model.cuda() + + dict_pretrained = {} + dict_col = {} + for name, param in model_pretrained.named_parameters(): + dict_pretrained[name] = param + for name, param in model.named_parameters(): + dict_col[name] = param + + for name, param in dict_pretrained.items(): + check_equal(param, dict_col[name]) def run_1d_row_tp(model_name: str): # A simple net with two stacked nn.Linear @@ -377,4 +397,5 @@ def test_model(world_size): if __name__ == '__main__': # test_model_parameters() # test_colo_optimizer() - test_model() + # test_model() + _test_pretrained() diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py index 4babb73cd..233f1bdcb 100644 --- a/tests/test_tensor/test_op.py +++ b/tests/test_tensor/test_op.py @@ -1,6 +1,6 @@ from numpy import allclose import torch -from colossalai.tensor import ColoTensor +from colossalai.tensor import ColoTensor, ColoParameter from copy import deepcopy from colossalai.utils import get_current_device @@ -16,7 +16,7 @@ def test_layernorm(): delattr(ln_op_colo, 'weight') weight_clone = ln_op.weight.clone().detach() weight_clone.requires_grad = True - setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone)) + setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone)) output = ln_op(input_t) output_colo = ln_op_colo(input_t_colo) @@ -39,8 +39,8 @@ def test_linear(): input_ref = torch.randn(1, in_dim) input_tensor = input_ref.clone() - sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight) - sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias) + sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight) + sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias) # replace the torch nn.Parameters with ShardedTensor delattr(fc, 'weight')