mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1.1 KiB
34 lines
1.1 KiB
from typing import Optional
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
|
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
|
|
|
from ._utils import GeneralTensor, convert_to_colo_tensor
|
|
|
|
|
|
@colo_op_impl(F.batch_norm)
def colo_batch_norm(
    input: GeneralTensor,
    running_mean: Optional[GeneralTensor],
    running_var: Optional[GeneralTensor],
    weight: Optional[GeneralTensor] = None,
    bias: Optional[GeneralTensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
):
    """ColoTensor handler for ``F.batch_norm``, registered through ``colo_op_impl``.

    ``input`` and ``bias`` are converted to ColoTensors on the process group of
    ``weight`` and redistributed with ``ReplicaSpec()`` before the call is
    forwarded to ``F.batch_norm``. The result is wrapped in a ``ColoTensor``
    carrying the same process group.

    Args:
        input: Tensor to normalize.
        running_mean: Running mean buffer, or ``None``.
        running_var: Running variance buffer, or ``None``.
        weight: Scale parameter. Must be a ``ColoTensor``; its process group is
            used for every conversion done here.
        bias: Shift parameter, or ``None`` for a bias-free normalization.
        training: Forwarded unchanged to ``F.batch_norm``.
        momentum: Forwarded unchanged to ``F.batch_norm``.
        eps: Forwarded unchanged to ``F.batch_norm``.

    Returns:
        The ``F.batch_norm`` output wrapped as a ``ColoTensor``.

    Raises:
        AssertionError: If ``weight`` is not a ``ColoTensor``.
    """
    assert isinstance(weight, ColoTensor), "colo_batch_norm requires weight to be a ColoTensor"

    # The running statistics are optional; only detach the ones that were given.
    if running_mean is not None:
        running_mean = running_mean.detach()
    if running_var is not None:
        running_var = running_var.detach()

    # Every conversion below uses the process group of the weight.
    pg = weight.get_process_group()

    input = convert_to_colo_tensor(input, pg)
    if bias is not None:
        bias = convert_to_colo_tensor(bias, pg)

    # Same redistribution order as before: input first, then bias.
    input = input.redistribute(ReplicaSpec())
    if bias is not None:
        bias = bias.redistribute(ReplicaSpec())

    output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
    return ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=pg))
|