2022-04-25 03:49:20 +00:00
|
|
|
import torch
|
|
|
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
2022-05-13 07:13:52 +00:00
|
|
|
from colossalai.tensor import ColoTensor, dist_spec
|
2022-04-25 03:49:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
@colo_op_impl(torch.nn.functional.layer_norm)
def colo_layernorm(types, args=(), kwargs=None, pg=None):
    """Handle ``torch.nn.functional.layer_norm`` when ColoTensor arguments are involved.

    The positional and keyword arguments of the intercepted call are mapped
    onto the ``layer_norm(input, normalized_shape, weight, bias, eps)``
    signature, any ColoTensor among ``input``/``weight``/``bias`` is replaced
    by the result of its ``torch_tensor()`` call, and the output of the native
    op is wrapped with ``ColoTensor.init_from_torch_tensor``.

    Args:
        types: Forwarded by the op-dispatch machinery; not used here.
        args: Positional arguments of the intercepted ``layer_norm`` call.
        kwargs: Keyword arguments of the intercepted call, or ``None``.
            A keyword value takes precedence over the positional one.
        pg: Not used here.

    Returns:
        The ``ColoTensor`` built from the native ``layer_norm`` result.

    Raises:
        TypeError: If ``input`` or ``normalized_shape`` is supplied neither
            positionally nor by keyword.
    """
    # ``kwargs`` defaults to None; normalise it so membership tests are safe.
    kwargs = {} if kwargs is None else kwargs

    # Positional order of torch.nn.functional.layer_norm.
    param_names = ('input', 'normalized_shape', 'weight', 'bias', 'eps')

    # Optional parameters start from layer_norm's own defaults so that a call
    # omitting them does not hit an unbound local.
    params = {'weight': None, 'bias': None, 'eps': 1e-5}
    # zip() pairs each positional value with the name at the same index, which
    # keeps weight/bias/eps at indices 2/3/4.
    params.update(zip(param_names, args))
    # Keyword arguments override positional ones.
    params.update((name, kwargs[name]) for name in param_names if name in kwargs)

    missing = [name for name in param_names[:2] if name not in params]
    if missing:
        raise TypeError(f"layer_norm() missing required argument(s): {', '.join(missing)}")

    input_tensor = params['input']
    normalized_shape = params['normalized_shape']
    weight = params['weight']
    bias = params['bias']
    eps = params['eps']

    if isinstance(input_tensor, ColoTensor):
        # TODO (ver217): check input dist spec
        # NOTE(review): presumably this makes the input replicated before the
        # local tensor is taken -- confirm against ColoTensor.to_dist_spec.
        input_tensor.to_dist_spec(dist_spec.replicate())
        input_tensor = input_tensor.torch_tensor()
    if isinstance(weight, ColoTensor):
        weight = weight.torch_tensor()
    if isinstance(bias, ColoTensor):
        bias = bias.torch_tensor()

    return ColoTensor.init_from_torch_tensor(
        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
|