mirror of https://github.com/hpcaitech/ColossalAI

[Tensor] add layer norm Op (#852)

parent a82da26f7e
commit 126ba573a8
@@ -1,3 +1,4 @@
 from .init import colo_uniform
 from .linear import colo_linear
 from .element_wise import colo_mean
+from .layernorm import colo_layernorm
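Editor's note: the new `colo_layernorm` handler is exported next to the existing ops, presumably from the ops package `__init__.py`. For orientation, here is a minimal sketch of the kind of registry `colo_op_impl` is assumed to maintain; the names below are illustrative, not the actual ColossalAI source:

import torch

_REGISTERED_OPS = {}  # hypothetical mapping: torch op -> ColoTensor-aware handler

def colo_op_impl(torch_op):
    """Register a handler for `torch_op` so that ColoTensor's __torch_function__
    can look it up instead of falling back to the plain torch implementation."""
    def decorator(handler):
        _REGISTERED_OPS[torch_op] = handler
        return handler
    return decorator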
@@ -5,8 +5,10 @@ from colossalai.tensor import ColoTensor

 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
-    stateful_tensor = args[0]
-    return torch.mean(stateful_tensor.torch_tensor())
+    input_t = args[0]
+    if isinstance(input_t, ColoTensor):
+        input_t = input_t.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))


 def register_elementwise_op(op):
@@ -22,7 +24,7 @@ def register_elementwise_op(op):
         # Validate types
         if not isinstance(input_tensor, ColoTensor):
             raise TypeError("input needs to be a ColoTensor")
-        return op(input_tensor.torch_tensor())
+        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))


 register_elementwise_op(torch.nn.functional.gelu)
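Editor's note: with this change, element-wise ops dispatched through the wrapper hand back a ColoTensor rather than a bare torch.Tensor, which is why the tests further down unwrap results with `.torch_tensor()`. A hedged usage sketch (illustrative, not part of the diff):

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.randn(3, 5))
out = torch.nn.functional.gelu(t)   # dispatched to the registered element-wise wrapper
raw = out.torch_tensor()            # unwrap back to a plain torch.Tensor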
@@ -0,0 +1,38 @@
+import torch
+
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ColoTensor
+
+
+@colo_op_impl(torch.nn.functional.layer_norm)
+def colo_layernorm(types, args=(), kwargs=None, pg=None):
+    # Unpack positional arguments following the signature of
+    # torch.nn.functional.layer_norm(input, normalized_shape, weight, bias, eps).
+    arg_num = len(args)
+    if arg_num > 0:
+        input_tensor = args[0]
+    if arg_num > 1:
+        normalized_shape = args[1]
+    if arg_num > 2:
+        weight = args[2]
+    if arg_num > 3:
+        bias = args[3]
+    if arg_num > 4:
+        eps = args[4]
+
+    # Keyword arguments take precedence over positional ones.
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    if 'weight' in kwargs:
+        weight = kwargs['weight']
+    if 'bias' in kwargs:
+        bias = kwargs['bias']
+    if 'eps' in kwargs:
+        eps = kwargs['eps']
+
+    # Unwrap any ColoTensor operands before calling the underlying torch op.
+    if isinstance(input_tensor, ColoTensor):
+        input_tensor = input_tensor.torch_tensor()
+    if isinstance(weight, ColoTensor):
+        weight = weight.torch_tensor()
+    if isinstance(bias, ColoTensor):
+        bias = bias.torch_tensor()
+
+    # Wrap the plain torch result back into a ColoTensor.
+    return ColoTensor.init_from_torch_tensor(
+        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
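Editor's note: once registered, a call to `torch.nn.functional.layer_norm` on a ColoTensor is intercepted and the result re-wrapped. A sketch of a direct call, passing all five arguments positionally the way `nn.LayerNorm` does internally (assumed dispatch path, not part of the diff):

import torch
from colossalai.tensor import ColoTensor

x = ColoTensor.init_from_torch_tensor(torch.randn(3, 2))
w = torch.ones(2, requires_grad=True)
b = torch.zeros(2, requires_grad=True)
y = torch.nn.functional.layer_norm(x, (2,), w, b, 1e-5)   # routed to colo_layernorm
assert isinstance(y, ColoTensor)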
@@ -8,6 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
 from colossalai.utils.cuda import get_current_device

+
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -145,3 +146,6 @@ class ColoTensor(object):
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
+
+    def backward(self, retain_graph: bool = False):
+        self._torch_tensor.backward(retain_graph=retain_graph)
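Editor's note: the new `backward()` simply forwards to the wrapped tensor, so a loss that is itself a ColoTensor (e.g. the output of `colo_mean`) can drive autograd directly, which the layer-norm test below relies on. A small usage sketch (illustrative only):

import torch
from colossalai.tensor import ColoTensor

x = ColoTensor.init_from_torch_tensor(torch.randn(4, requires_grad=True))
loss = torch.mean(x)   # colo_mean wraps the result in a ColoTensor
loss.backward()        # forwards to the underlying torch.Tensor's backward()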
@@ -1,7 +1,32 @@
-from numpy import allclose, require
+from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils import get_current_device
+
+
+def test_layernorm():
+    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
+    ln_op_colo = deepcopy(ln_op)
+
+    input_t = torch.randn(3, 2, device=get_current_device())
+    input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
+
+    # prepare colossalai LN
+    delattr(ln_op_colo, 'weight')
+    weight_clone = ln_op.weight.clone().detach()
+    weight_clone.requires_grad = True
+    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+
+    output = ln_op(input_t)
+    output_colo = ln_op_colo(input_t_colo)
+
+    assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
+
+    torch.mean(output).backward()
+    torch.mean(output_colo).backward()
+
+    assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())


 def test_linear():
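Editor's note on the delattr/setattr pair in test_layernorm above: `nn.Module.__setattr__` refuses to overwrite a registered Parameter with a non-Parameter value, so the Parameter slot is removed first and the ColoTensor is then attached as a plain attribute that `nn.LayerNorm.forward` picks up. A minimal sketch of the same pattern (illustrative, not part of the diff):

import torch
from colossalai.tensor import ColoTensor

m = torch.nn.LayerNorm(2)
delattr(m, 'weight')                                           # drop the registered Parameter
m.weight = ColoTensor.init_from_torch_tensor(torch.ones(2))    # now a plain attribute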
@@ -50,8 +75,8 @@ def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
-    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
-    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))


     # Test a function not wrapped by
@@ -76,4 +101,5 @@ def check_all():


 if __name__ == '__main__':
-    test_lazy_init_tensor()
+    # test_lazy_init_ptensor()
+    test_layernorm()