ColossalAI/colossalai/nn/_ops/layernorm.py

25 lines
901 B
Python
Raw Normal View History

2022-04-25 03:49:20 +00:00
import torch
import torch.nn.functional as F
from typing import List, Optional
2022-04-25 03:49:20 +00:00
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor
2022-04-25 03:49:20 +00:00
@colo_op_impl(F.layer_norm)
def colo_layernorm(
    input_tensor: GeneralTensor,
    normalized_shape: List[int],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    eps: float = 1e-5,
):
    """Layer normalization for ColoTensor inputs.

    Decorated with ``colo_op_impl(F.layer_norm)``; presumably this makes it the
    implementation used when ``F.layer_norm`` is applied to a ColoTensor
    (confirm in ``colossalai.tensor.op_wrapper``).

    The input is converted to a replicated dist spec over its own process
    group. The stock ``F.layer_norm`` then runs on it, and the result is
    wrapped in a ``ColoTensor`` that carries the replicated input's tensor spec.

    Args:
        input_tensor: Tensor to normalize, passed through ``convert_to_colo_tensor``.
        normalized_shape: Forwarded unchanged to ``F.layer_norm``.
        weight: Optional affine scale, passed through ``convert_to_colo_tensor``.
        bias: Optional affine shift, passed through ``convert_to_colo_tensor``.
        eps: Forwarded unchanged to ``F.layer_norm``.

    Returns:
        A ``ColoTensor`` built from the ``F.layer_norm`` output, using the
        tensor spec of the replicated input.
    """
    # Convert input, weight and bias in that order. weight/bias default to
    # None; NOTE(review): this assumes convert_to_colo_tensor passes None
    # through unchanged — confirm in ._utils.
    colo_input = convert_to_colo_tensor(input_tensor)
    colo_weight = convert_to_colo_tensor(weight)
    colo_bias = convert_to_colo_tensor(bias)

    # TODO (ver217): check dist spec
    # The input is forced to a replicated layout regardless of its current
    # spec; the dist specs of weight/bias are not inspected here.
    replicate_spec = distspec.replicate(colo_input.get_process_group())
    replicated = colo_input.convert_to_dist_spec(replicate_spec)

    normed = F.layer_norm(replicated, normalized_shape, weight=colo_weight, bias=colo_bias, eps=eps)
    return ColoTensor.from_torch_tensor(normed, replicated.tensor_spec)