diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py index e3eef9b18..2b761b84e 100644 --- a/colossalai/nn/_ops/layernorm.py +++ b/colossalai/nn/_ops/layernorm.py @@ -19,9 +19,7 @@ def colo_layernorm( input_tensor = input_tensor.redistribute(ReplicaSpec()) output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps) - output = ColoTensor.from_torch_tensor( - tensor=output, - spec=ColoTensorSpec( - pg=input_tensor.get_process_group(), - dist_attr=input_tensor.dist_spec)) + output = ColoTensor.from_torch_tensor(tensor=output, + spec=ColoTensorSpec(pg=input_tensor.get_process_group(), + dist_attr=input_tensor.dist_spec)) return output