mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] add from_pretrained support and bert pretrained test (#921)
* add from_pretrained support and test
* polish
* polish
* polish
* polish
parent 1d625fcd36
commit c195d2814c
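What this commit enables, in short: calling a Hugging Face from_pretrained constructor inside ColoInitContext now yields a model whose parameters are ColoParameters while the values still match the pretrained checkpoint. The sketch below is adapted from the _test_pretrained test added further down; the ColoInitContext import path and the allclose-based comparison stand in for the test's check_equal helper and are assumptions, not part of the diff.

import torch
from transformers import BertForMaskedLM
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext  # import path assumed

# Parameters created inside the context become ColoParameters, while
# from_pretrained still fills them with the checkpoint weights.
with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# A vanilla model loaded the usual way, used as the reference.
model_ref = BertForMaskedLM.from_pretrained('bert-base-uncased').cuda()
model = model.cuda()

ref_params = {name: p for name, p in model_ref.named_parameters()}
for name, p in model.named_parameters():
    # ColoParameter keeps its payload in an internal torch tensor.
    assert torch.allclose(ref_params[name], p.torch_tensor())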
@@ -1,7 +1,7 @@
 from .op_wrapper import _COLOSSAL_OPS
 import torch
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Union
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide

@@ -55,6 +55,15 @@ class ColoTensor(object):
     def data(self):
         return self._torch_tensor.data

+    @data.setter
+    def data(self, tensor: Union[torch.Tensor, "ColoTensor"]):
+        if isinstance(tensor, ColoTensor):
+            self._torch_tensor.data = tensor.data
+        elif isinstance(tensor, torch.Tensor):
+            self._torch_tensor.data = tensor
+        else:
+            raise NotImplementedError
+
     @property
     def grad(self):
         return self._torch_tensor.grad

@@ -148,14 +157,31 @@ class ColoTensor(object):
         assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        if self._shard_pattern == ShardPattern.Row:
-            dim = 0
-        elif self._shard_pattern == ShardPattern.Col:
-            dim = -1
+        dim = self._get_gather_dim()
         self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
         self._shard_pattern = ShardPattern.NA
         self._size = self._torch_tensor.size()

+    def global_torch_tensor(self) -> torch.Tensor:
+        out_tensor = self.torch_tensor()
+        if self.is_gathered():
+            return out_tensor
+
+        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
+        world_size = gpc.get_world_size(parallel_action.parallel_mode)
+        if world_size == 1:
+            return out_tensor
+
+        rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
+        tensor_list[rank] = out_tensor
+        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))
+
+        dim = self._get_gather_dim()
+        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
+
+        return out_tensor
+
     def is_gathered(self) -> bool:
         return self._shard_pattern == ShardPattern.NA

@@ -212,9 +238,7 @@ class ColoTensor(object):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

     def __getattr__(self, name):

         def replace_tensor_with_colo(func):

             def execute_func(*args, **kwargs):
                 # transform the ColoTensor args to torch Tensor.
                 args = [arg.torch_tensor() if isinstance(arg, ColoTensor) else arg for arg in args]

@@ -225,7 +249,9 @@ class ColoTensor(object):
             return execute_func

-        assert hasattr(self._torch_tensor, name), f"torch.Tensor has not attribute named as {name}. So is ColoTensor"
+        if hasattr(self._torch_tensor, name) == False:
+            raise AttributeError
+
         attr = getattr(self._torch_tensor, name)

         if isinstance(attr, Callable):

@@ -244,3 +270,12 @@ class ColoTensor(object):
                 ColoTensor.init_from_torch_tensor(output) if type(output) is torch.Tensor else output
                 for output in outputs
             ])
+
+    def _get_gather_dim(self):
+        if self._shard_pattern == ShardPattern.Row:
+            dim = 0
+        elif self._shard_pattern == ShardPattern.Col:
+            dim = -1
+        else:
+            raise NotImplementedError
+        return dim
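Two details worth noting in the hunks above: the Row/Col branch in gather() is factored into the new _get_gather_dim() helper so that global_torch_tensor() can reuse it, and the new data setter lets code written against plain torch tensors assign into a ColoTensor's payload. A minimal sketch of the setter's accepted inputs, assuming a non-lazy ColoTensor built with init_from_torch_tensor:

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.zeros(2, 2))

# A plain torch.Tensor is assigned directly to the underlying payload.
t.data = torch.ones(2, 2)

# Another ColoTensor is unwrapped and its payload is used instead.
t.data = ColoTensor.init_from_torch_tensor(torch.full((2, 2), 3.0))

# Any other type is rejected, matching the NotImplementedError branch above.
try:
    t.data = [1.0, 2.0]
except NotImplementedError:
    pass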
@@ -4,8 +4,86 @@ from colossalai.tensor import ColoTensor, ColoParameter
 import types

 from torch import nn
-from typing import Iterator, Tuple, Union
+from typing import Iterator, Tuple, Union, Optional
+
+
+# Adapted from torch.nn.module.Module.register_param
+def _register_parameter_with_colotensor(self, name: str, param):
+    if '_parameters' not in self.__dict__:
+        raise AttributeError(
+            "cannot assign parameter before Module.__init__() call")
+
+    if not isinstance(name, torch._six.string_classes):
+        raise TypeError("parameter name should be a string. "
+                        "Got {}".format(torch.typename(name)))
+    if '.' in name:
+        raise KeyError("parameter name can't contain \".\"")
+    if name == '':
+        raise KeyError("parameter name can't be empty string \"\"")
+    if hasattr(self, name) and name not in self._parameters:
+        raise KeyError("attribute '{}' already exists".format(name))
+
+    if param is None:
+        self._parameters[name] = None
+    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
+        raise TypeError("cannot assign '{}' object to parameter '{}' "
+                        "(torch.nn.Parameter or ColoParameter or None required)"
+                        .format(torch.typename(param), name))
+    elif param.grad_fn:
+        raise ValueError(
+            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
+            "parameters must be created explicitly. To express '{0}' "
+            "as a function of another Tensor, compute the value in "
+            "the forward() method.".format(name))
+    else:
+        self._parameters[name] = param
+
+
+# Adapted from torch.nn.module.Module.__setattr__
+def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):
+
+    def remove_from(*dicts_or_sets):
+        for d in dicts_or_sets:
+            if name in d:
+                if isinstance(d, dict):
+                    del d[name]
+                else:
+                    d.discard(name)
+
+    params = self.__dict__.get('_parameters')
+    if isinstance(value, (ColoTensor, torch.nn.Parameter)):
+        if params is None:
+            raise AttributeError(
+                "cannot assign parameters before Module.__init__() call")
+        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
+        self.register_parameter(name, value)
+    elif params is not None and name in params:
+        if value is not None:
+            raise TypeError("cannot assign '{}' as parameter '{}' "
+                            "(torch.nn.Parameter or None expected)"
+                            .format(torch.typename(value), name))
+        self.register_parameter(name, value)
+    else:
+        modules = self.__dict__.get('_modules')
+        if isinstance(value, torch.nn.Module):
+            if modules is None:
+                raise AttributeError(
+                    "cannot assign module before Module.__init__() call")
+            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
+            modules[name] = value
+        elif modules is not None and name in modules:
+            if value is not None:
+                raise TypeError("cannot assign '{}' as child module '{}' "
+                                "(torch.nn.Module or None expected)"
+                                .format(torch.typename(value), name))
+            modules[name] = value
+        else:
+            buffers = self.__dict__.get('_buffers')
+            if buffers is not None and name in buffers:
+                if value is not None and not isinstance(value, torch.Tensor):
+                    raise TypeError("cannot assign '{}' as buffer '{}' "
+                                    "(torch.Tensor or None expected)"
+                                    .format(torch.typename(value), name))
+                buffers[name] = value
+            else:
+                object.__setattr__(self, name, value)
+
+
 def ColoModulize(module):
     """

@@ -64,7 +142,6 @@ def ColoModulize(module):
     module.colo_named_parameters = funcType(colo_named_parameters, module)
     module._colo_visited = True

 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):

@@ -77,11 +154,16 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device

+        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
+        torch.nn.Module.__setattr__ = _setattr_with_colotensor
+        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
         if hasattr(module, '_colo_visited'):
             return

@@ -100,7 +182,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             tensor_detached = param.to(self._device).detach()
             tensor_detached.requires_grad = requires_grad

-            setattr(module, name,
-                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))
+            colo_param = ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload)
+            setattr(module, name, colo_param)

         ColoModulize(module)
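Because ColoInitContext's constructor patches torch.nn.Module.__setattr__ and register_parameter globally (see the TODO above about restoring them on exit), assigning a ColoParameter to a module attribute lands in module._parameters just like a torch.nn.Parameter would. A hedged illustration of that effect; the module, shapes, and import path are only for demonstration and are not taken from the diff:

import torch
from colossalai.tensor import ColoParameter
from colossalai.utils.model.colo_init_context import ColoInitContext  # import path assumed

with ColoInitContext(lazy_memory_allocate=False, device=torch.device('cpu')):
    fc = torch.nn.Linear(4, 4)

# _post_init_method replaced the original weights, so the registered
# parameters are now ColoParameters rather than torch.nn.Parameters.
assert all(isinstance(p, ColoParameter) for p in fc._parameters.values())

# The patched __setattr__ routes manual assignments through
# register_parameter, which now also accepts ColoParameter.
new_weight = ColoParameter.init_from_torch_tensor(torch.randn(4, 4))
fc.weight = new_weight
assert fc._parameters['weight'] is new_weight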
@@ -23,7 +23,7 @@ from transformers.file_utils import ModelOutput
 from dataclasses import fields


-def _post_init_colo(self):
+def _post_init_colotensor(self):
     class_fields = fields(self)
     # Safety and consistency checks
     if len(class_fields) == 0:

@@ -72,7 +72,7 @@ def _post_init_colo(self):
             self[field.name] = v


-ModelOutput.__post_init__ = _post_init_colo
+ModelOutput.__post_init__ = _post_init_colotensor
 # complete the hack
@@ -278,6 +278,26 @@ def test_colo_optimizer():
         if i > 5:
             break

+
+def _test_pretrained():
+    from _utils import check_equal
+    from transformers import BertForMaskedLM
+    set_seed(1)
+    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
+    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+    model_pretrained = model_pretrained.cuda()
+    model = model.cuda()
+
+    dict_pretrained = {}
+    dict_col = {}
+    for name, param in model_pretrained.named_parameters():
+        dict_pretrained[name] = param
+    for name, param in model.named_parameters():
+        dict_col[name] = param
+
+    for name, param in dict_pretrained.items():
+        check_equal(param, dict_col[name])
+
+
 def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear

@@ -377,4 +397,5 @@ def test_model(world_size):
 if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
-    test_model()
+    # test_model()
+    _test_pretrained()
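_test_pretrained leans on a check_equal helper imported from the local _utils module; its body is not part of this diff. A hypothetical stand-in that matches how it is called here (the tolerances and the torch_tensor() unwrapping are assumptions) could look like this:

import torch

def check_equal(a, b, rtol: float = 1e-5, atol: float = 1e-8):
    """Hypothetical stand-in for the tests' _utils.check_equal: compare a torch
    parameter against a ColoParameter (or plain tensor) element-wise."""
    a = a.torch_tensor() if hasattr(a, 'torch_tensor') else a
    b = b.torch_tensor() if hasattr(b, 'torch_tensor') else b
    assert torch.allclose(a, b, rtol=rtol, atol=atol)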
@@ -1,6 +1,6 @@
 from numpy import allclose
 import torch
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ColoParameter
 from copy import deepcopy
 from colossalai.utils import get_current_device

@@ -16,7 +16,7 @@ def test_layernorm():
     delattr(ln_op_colo, 'weight')
     weight_clone = ln_op.weight.clone().detach()
     weight_clone.requires_grad = True
-    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+    setattr(ln_op_colo, 'weight', ColoParameter.init_from_torch_tensor(tensor=weight_clone))

     output = ln_op(input_t)
     output_colo = ln_op_colo(input_t_colo)

@@ -39,8 +39,8 @@ def test_linear():
     input_ref = torch.randn(1, in_dim)
     input_tensor = input_ref.clone()

-    sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight)
-    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
+    sharded_weight = ColoParameter.init_from_torch_tensor(fc_ref.weight)
+    sharded_bias = ColoParameter.init_from_torch_tensor(fc_ref.bias)

     # replace the torch nn.Parameters with ShardedTensor
     delattr(fc, 'weight')