mirror of https://github.com/hpcaitech/ColossalAI
from typing import List, Optional

import torch
import torch.nn.functional as F

from colossalai.tensor import ColoTensor, distspec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


@colo_op_impl(F.layer_norm)
def colo_layernorm(
    input_tensor: GeneralTensor,
    normalized_shape: List[int],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    eps: float = 1e-5,
):
    """Handle ``F.layer_norm`` for ``ColoTensor`` arguments, registered via ``colo_op_impl``."""
    # Wrap any plain torch.Tensor arguments as ColoTensors so they carry a dist spec.
    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

    # TODO (ver217): check dist spec
    # Convert the input to a replicated dist spec over its process group before
    # applying the dense torch kernel.
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))

    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
    # Re-wrap the plain output tensor as a ColoTensor carrying the (now replicated) input spec.
    output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
    return output
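
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes the
# colo_op_impl registry routes F.layer_norm calls on ColoTensor inputs to
# colo_layernorm, and that ColoTensor.from_torch_tensor attaches a default
# spec when none is given; the shapes below are arbitrary.
#
#   x = ColoTensor.from_torch_tensor(torch.randn(4, 16))
#   out = F.layer_norm(x, normalized_shape=[16])  # dispatched to colo_layernorm
#   assert isinstance(out, ColoTensor)
# ---------------------------------------------------------------------------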