mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] add from_pretrained support and bert pretrained test (#921)
* add from_pretrained support and test
* polish
* polish
* polish
* polish

pull/922/head
parent 1d625fcd36
commit c195d2814c
@@ -1,7 +1,7 @@
 from .op_wrapper import _COLOSSAL_OPS
 import torch
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Union
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
@@ -55,6 +55,15 @@ class ColoTensor(object):
     def data(self):
         return self._torch_tensor.data

+    @data.setter
+    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
+        if isinstance(tensor, ColoTensor):
+            self._torch_tensor.data = tensor.data
+        elif isinstance(tensor, torch.Tensor):
+            self._torch_tensor.data = tensor
+        else:
+            raise NotImplementedError
+
     @property
     def grad(self):
         return self._torch_tensor.grad
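Editor's note: this setter is what lets the HuggingFace loading path (which assigns pretrained weights via `param.data = ...`) write into a ColoTensor-backed parameter. A minimal sketch of the behaviour, using only the `init_from_torch_tensor` and `torch_tensor` APIs already present in this file (illustrative, not part of the commit):

import torch
from colossalai.tensor import ColoTensor

colo_t = ColoTensor.init_from_torch_tensor(torch.zeros(4, 4))
pretrained_weight = torch.randn(4, 4)

# The new setter accepts both a plain torch.Tensor and another ColoTensor.
colo_t.data = pretrained_weight
assert torch.equal(colo_t.torch_tensor(), pretrained_weight)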
@@ -148,14 +157,31 @@ class ColoTensor(object):
         assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        if self._shard_pattern == ShardPattern.Row:
-            dim = 0
-        elif self._shard_pattern == ShardPattern.Col:
-            dim = -1
+        dim = self._get_gather_dim()
         self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
         self._shard_pattern = ShardPattern.NA
         self._size = self._torch_tensor.size()

+    def global_torch_tensor(self) -> torch.Tensor:
+        out_tensor = self.torch_tensor()
+        if self.is_gathered():
+            return out_tensor
+
+        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
+        world_size = gpc.get_world_size(parallel_action.parallel_mode)
+        if world_size == 1:
+            return out_tensor
+
+        rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
+        tensor_list[rank] = out_tensor
+        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))
+
+        dim = self._get_gather_dim()
+        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
+
+        return out_tensor
+
     def is_gathered(self) -> bool:
         return self._shard_pattern == ShardPattern.NA

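Editor's note: `global_torch_tensor` follows the standard all-gather-then-concatenate reconstruction of a sharded tensor. The same pattern in isolation, as a hedged sketch using plain `torch.distributed` rather than the ColossalAI API (process-group setup omitted; `dim` plays the role of `_get_gather_dim`):

import torch
import torch.distributed as dist

def gather_full_tensor(shard: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # Collect every rank's shard, then stitch them back along the sharded dim.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return shard
    shards = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(shards, shard, group=group)
    return torch.cat(shards, dim=dim).contiguous()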
@@ -212,9 +238,7 @@ class ColoTensor(object):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

     def __getattr__(self, name):

         def replace_tensor_with_colo(func):

             def execute_func(*args, **kwargs):
                 # transform the ColoTensor args to torch Tensor.
                 args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]
@@ -225,7 +249,9 @@ class ColoTensor(object):

             return execute_func

-        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
+        if hasattr(self._torch_tensor, name) == False:
+            raise AttributeError
+
         attr = getattr(self._torch_tensor, name)

         if isinstance(attr, Callable):
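Editor's note: with this hook, any `torch.Tensor` attribute that ColoTensor does not define itself is looked up on the wrapped tensor; callable attributes are wrapped so ColoTensor arguments are unwrapped before the call, and missing names now raise AttributeError instead of tripping an assert. A hedged illustration (whether a single tensor result comes back re-wrapped as a ColoTensor depends on the wrapping code above):

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.ones(2, 3))

# `sum` is assumed not to be defined on ColoTensor, so __getattr__ forwards
# it to the wrapped torch.Tensor.
total = t.sum()

# Missing attributes now raise AttributeError rather than failing an assert.
try:
    t.not_a_real_attribute
except AttributeError:
    pass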
@@ -244,3 +270,12 @@ class ColoTensor(object):
             ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
             for output in outputs
         ])
+
+    def _get_gather_dim(self):
+        if self._shard_pattern == ShardPattern.Row:
+            dim = 0
+        elif self._shard_pattern == ShardPattern.Col:
+            dim = -1
+        else:
+            raise NotImplementedError
+        return dim
@@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter
 import types

 from torch import nn
-from typing import Iterator, Tuple, Union
+from typing import Iterator, Tuple, Union, Optional

+# Adapted from torch.nn.module.Module.register_param
+def _register_parameter_with_colotensor(self, name: str, param):
+    if '_parameters' not in self.__dict__:
+        raise AttributeError(
+            "cannot assign parameter before Module.__init__() call")
+
+    if not isinstance(name, torch._six.string_classes):
+        raise TypeError("parameter name should be a string. "
+                        "Got {}".format(torch.typename(name)))
+    if '.' in name:
+        raise KeyError("parameter name can't contain \".\"")
+    if name == '':
+        raise KeyError("parameter name can't be empty string \"\"")
+    if hasattr(self, name) and name not in self._parameters:
+        raise KeyError("attribute '{}' already exists".format(name))
+
+    if param is None:
+        self._parameters[name] = None
+    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
+        raise TypeError("cannot assign '{}' object to parameter '{}' "
+                        "(torch.nn.Parameter or ColoParameter or None required)"
+                        .format(torch.typename(param), name))
+    elif param.grad_fn:
+        raise ValueError(
+            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
+            "parameters must be created explicitly. To express '{0}' "
+            "as a function of another Tensor, compute the value in "
+            "the forward() method.".format(name))
+    else:
+        self._parameters[name] = param
+
+# Adapted from torch.nn.module.Module.__setattr__
+def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
+    def remove_from(*dicts_or_sets):
+        for d in dicts_or_sets:
+            if name in d:
+                if isinstance(d, dict):
+                    del d[name]
+                else:
+                    d.discard(name)
+
+    params = self.__dict__.get('_parameters')
+    if isinstance(value, (ColoTensor, torch.nn.Parameter)):
+        if params is None:
+            raise AttributeError(
+                "cannot assign parameters before Module.__init__() call")
+        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
+        self.register_parameter(name, value)
+    elif params is not None and name in params:
+        if value is not None:
+            raise TypeError("cannot assign '{}' as parameter '{}' "
+                            "(torch.nn.Parameter or None expected)"
+                            .format(torch.typename(value), name))
+        self.register_parameter(name, value)
+    else:
+        modules = self.__dict__.get('_modules')
+        if isinstance(value, torch.nn.Module):
+            if modules is None:
+                raise AttributeError(
+                    "cannot assign module before Module.__init__() call")
+            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
+            modules[name] = value
+        elif modules is not None and name in modules:
+            if value is not None:
+                raise TypeError("cannot assign '{}' as child module '{}' "
+                                "(torch.nn.Module or None expected)"
+                                .format(torch.typename(value), name))
+            modules[name] = value
+        else:
+            buffers = self.__dict__.get('_buffers')
+            if buffers is not None and name in buffers:
+                if value is not None and not isinstance(value, torch.Tensor):
+                    raise TypeError("cannot assign '{}' as buffer '{}' "
+                                    "(torch.Tensor or None expected)"
+                                    .format(torch.typename(value), name))
+                buffers[name] = value
+            else:
+                object.__setattr__(self, name, value)
+
 def ColoModulize(module):
     """
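Editor's note: once these replacements are installed on `torch.nn.Module` (the context manager below does this), assigning a ColoParameter to a module attribute is routed through `register_parameter` and lands in `module._parameters` rather than in the instance `__dict__`. A hedged sketch of that effect, mirroring the delattr/setattr pattern used in the tests at the end of this diff, and assuming ColoParameter is a ColoTensor subclass as the isinstance checks suggest (the manual patching here is for illustration only; nothing restores the originals):

import torch
from colossalai.tensor import ColoParameter

# Illustration only: ColoInitContext performs this patching for you.
# Assumes the two helpers defined above are in scope.
torch.nn.Module.__setattr__ = _setattr_with_colotensor
torch.nn.Module.register_parameter = _register_parameter_with_colotensor

fc = torch.nn.Linear(4, 4)
delattr(fc, 'weight')

w = fc.in_features and torch.randn(4, 4)
w.requires_grad = True
fc.weight = ColoParameter.init_from_torch_tensor(tensor=w)

# The ColoParameter is registered as a parameter, not stored as a plain attribute.
assert 'weight' in fc._parameters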
@@ -64,7 +142,6 @@ def ColoModulize(module):
     module.colo_named_parameters = funcType(colo_named_parameters, module)
     module._colo_visited = True

-
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
@@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device

+        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
+        torch.nn.Module.__setattr__ = _setattr_with_colotensor
+        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
+        FIXME(fjr) The module may be passed to this function multiple times?
         """

         if hasattr(module, '_colo_visited'):
             return

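Editor's note: one possible answer to the TODO above, sketched as a hypothetical helper that is not part of this commit, would be to remember the stock methods before patching and restore them when the context exits:

import torch

# Hypothetical sketch only: capture the originals at patch time...
_orig_setattr = torch.nn.Module.__setattr__
_orig_register_parameter = torch.nn.Module.register_parameter

def _restore_module_methods():
    # ...and call this from the context's exit path to undo the patch.
    torch.nn.Module.__setattr__ = _orig_setattr
    torch.nn.Module.register_parameter = _orig_register_parameter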
@@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             tensor_detached = param.to(self._device).detach()
             tensor_detached.requires_grad = requires_grad

-            setattr(module, name,
-                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
+            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
+            setattr(module, name, colo_param)

-            ColoModulize(module)
+        ColoModulize(module)
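Editor's note: with `_post_init_method` swapping every parameter for a ColoParameter, any model constructed inside the context ends up with colo-managed parameters; the new pretrained test below relies on exactly this. A usage sketch (the `ColoInitContext` import path is an assumption; adjust it to where the class lives in your tree):

import torch
from colossalai.tensor import ColoParameter
# Import path assumed for illustration.
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(lazy_memory_allocate=False, device=torch.device('cpu')):
    model = torch.nn.Linear(16, 16)

# _post_init_method has replaced each torch.nn.Parameter with a ColoParameter.
assert isinstance(model.weight, ColoParameter)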
@@ -23,7 +23,7 @@ from transformers.file_utils import ModelOutput
 from dataclasses import fields


-def _post_init_colo(self):
+def _post_init_colotensor(self):
     class_fields = fields(self)
     # Safety and consistency checks
     if len(class_fields) == 0:
@@ -72,7 +72,7 @@ def _post_init_colo(self):
             self[field.name] = v


-ModelOutput.__post_init__ = _post_init_colo
+ModelOutput.__post_init__ = _post_init_colotensor
 # complete the hack

@@ -278,6 +278,26 @@ def test_colo_optimizer():
         if i > 5:
             break

+def _test_pretrained():
+    from _utils import check_equal
+    from transformers import BertForMaskedLM
+    set_seed(1)
+    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+    model_pretrained = model_pretrained.cuda()
+    model = model.cuda()
+
+    dict_pretrained = {}
+    dict_col = {}
+    for name, param in model_pretrained.named_parameters():
+        dict_pretrained[name] = param
+    for name, param in model.named_parameters():
+        dict_col[name] = param
+
+    for name, param in dict_pretrained.items():
+        check_equal(param, dict_col[name])

 def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
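Editor's note: `check_equal` comes from the test-local `_utils` module, which is not shown in this diff. A plausible stand-in that matches how it is called here, where one argument may be a ColoTensor/ColoParameter and the other a plain parameter (a guess, not the repository's actual helper):

import torch

def check_equal(a, b, rtol: float = 1e-5, atol: float = 1e-8) -> None:
    # Unwrap ColoTensor-like objects to their underlying torch.Tensor payload.
    a = a.torch_tensor() if hasattr(a, 'torch_tensor') else a
    b = b.torch_tensor() if hasattr(b, 'torch_tensor') else b
    assert torch.allclose(a.detach().cpu().float(), b.detach().cpu().float(), rtol=rtol, atol=atol)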
@@ -377,4 +397,5 @@ def test_model(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model()
+    # test_model()
+    _test_pretrained()
@@ -1,6 +1,6 @@
 from numpy import allclose
 import torch
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ColoParameter
 from copy import deepcopy
 from colossalai.utils import get_current_device

@@ -16,7 +16,7 @@ def test_layernorm():
     delattr(ln_op_colo, 'weight')
     weight_clone = ln_op.weight.clone().detach()
     weight_clone.requires_grad = True
-    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+    setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))

     output = ln_op(input_t)
     output_colo = ln_op_colo(input_t_colo)
@@ -39,8 +39,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
-    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
+    sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')