[Tensor] add layer norm Op (#852)

pull/867/head
Jiarui Fang 2022-04-25 11:49:20 +08:00 committed by GitHub
parent a82da26f7e
commit 126ba573a8
5 changed files with 79 additions and 8 deletions

View File

@@ -1,3 +1,4 @@
 from .init import colo_uniform
 from .linear import colo_linear
 from .element_wise import colo_mean
+from .layernorm import colo_layernorm

View File

@@ -5,8 +5,10 @@ from colossalai.tensor import ColoTensor
 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
-    stateful_tensor = args[0]
-    return torch.mean(stateful_tensor.torch_tensor())
+    input_t = args[0]
+    if isinstance(input_t, ColoTensor):
+        input_t = input_t.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))
 
 
 def register_elementwise_op(op):

@@ -22,7 +24,7 @@ def register_elementwise_op(op):
         # Validate types
         if not isinstance(input_tensor, ColoTensor):
            raise TypeError("input needs to be a ColoTensor")
-        return op(input_tensor.torch_tensor())
+        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))
 
 
 register_elementwise_op(torch.nn.functional.gelu)
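
For reference (not part of the commit), the behavior change is that the wrapped ops now hand results back as ColoTensor rather than plain torch.Tensor. A minimal sketch using only the APIs shown in this diff:

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.randn(3, 5))

m = torch.mean(t)                  # dispatches to colo_mean
g = torch.nn.functional.gelu(t)    # dispatches to the registered elementwise op

# both results come back wrapped
assert isinstance(m, ColoTensor) and isinstance(g, ColoTensor)
assert torch.allclose(g.torch_tensor(), torch.nn.functional.gelu(t.torch_tensor()))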

View File

@@ -0,0 +1,38 @@
+import torch
+
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
+
+
+@colo_op_impl(torch.nn.functional.layer_norm)
+def colo_layernorm(types, args=(), kwargs=None, pg=None):
+    # defaults mirror torch.nn.functional.layer_norm
+    weight, bias, eps = None, None, 1e-5
+    if kwargs is None:
+        kwargs = {}
+
+    # positional signature: (input, normalized_shape, weight, bias, eps)
+    arg_num = len(args)
+    if arg_num > 0:
+        input_tensor = args[0]
+    if arg_num > 1:
+        normalized_shape = args[1]
+    if arg_num > 2:
+        weight = args[2]
+    if arg_num > 3:
+        bias = args[3]
+    if arg_num > 4:
+        eps = args[4]
+
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    if 'weight' in kwargs:
+        weight = kwargs['weight']
+    if 'bias' in kwargs:
+        bias = kwargs['bias']
+    if 'eps' in kwargs:
+        eps = kwargs['eps']
+
+    # unwrap ColoTensor arguments before calling the native op
+    if isinstance(input_tensor, ColoTensor):
+        input_tensor = input_tensor.torch_tensor()
+    if isinstance(weight, ColoTensor):
+        weight = weight.torch_tensor()
+    if isinstance(bias, ColoTensor):
+        bias = bias.torch_tensor()
+
+    return ColoTensor.init_from_torch_tensor(
+        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
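
As a usage sketch (illustrative, not from the commit): once this handler is registered, calling torch.nn.functional.layer_norm on ColoTensor arguments routes through colo_layernorm and the result comes back wrapped. The shapes below are assumptions chosen to satisfy layer_norm:

import torch
from colossalai.tensor import ColoTensor

x = ColoTensor.init_from_torch_tensor(torch.randn(3, 2))
weight = ColoTensor.init_from_torch_tensor(torch.ones(2))
bias = ColoTensor.init_from_torch_tensor(torch.zeros(2))

# dispatches to colo_layernorm via ColoTensor's __torch_function__ hook
out = torch.nn.functional.layer_norm(x, (2,), weight, bias, 1e-5)
assert isinstance(out, ColoTensor)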

View File

@@ -8,6 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
 from colossalai.utils.cuda import get_current_device
 
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.

@@ -145,3 +146,6 @@ class ColoTensor(object):
             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
             return func(*args, **kwargs)
+
+    def backward(self, retain_graph: bool = False):
+        self._torch_tensor.backward(retain_graph=retain_graph)
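
A short sketch of how the new backward method composes with the ColoTensor-returning colo_mean above (illustrative; it mirrors the pattern used in the test below):

import torch
from colossalai.tensor import ColoTensor

w = torch.randn(4, requires_grad=True)
t = ColoTensor.init_from_torch_tensor(w)   # wrap a leaf tensor that tracks gradients

loss = torch.mean(t)    # dispatched to colo_mean, returns a ColoTensor
loss.backward()         # delegates to the wrapped torch.Tensor's backward
assert w.grad is not None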

View File

@@ -1,7 +1,32 @@
-from numpy import allclose, require
+from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils import get_current_device
+
+
+def test_layernorm():
+    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
+    ln_op_colo = deepcopy(ln_op)
+
+    input_t = torch.randn(3, 2, device=get_current_device())
+    input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
+
+    # prepare colossalai LN
+    delattr(ln_op_colo, 'weight')
+    weight_clone = ln_op.weight.clone().detach()
+    weight_clone.requires_grad = True
+    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+
+    output = ln_op(input_t)
+    output_colo = ln_op_colo(input_t_colo)
+
+    assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
+
+    torch.mean(output).backward()
+    torch.mean(output_colo).backward()
+
+    assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
+
+
 def test_linear():

@@ -50,8 +75,8 @@ def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
-    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
-    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))
 
     # Test a function not wrapped by

@@ -76,4 +101,5 @@ def check_all():
 if __name__ == '__main__':
-    test_lazy_init_tensor()
+    # test_lazy_init_ptensor()
+    test_layernorm()