From 5a4a3b77d91e23d521768befaf311365113aadec Mon Sep 17 00:00:00 2001
From: Jiang Zhuo <58946806+jiangz17THU@users.noreply.github.com>
Date: Thu, 10 Mar 2022 17:15:59 +0800
Subject: [PATCH] fix format (#376)

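Pure re-indentation of existing code; no behavioural change intended. For reference, a minimal usage sketch of the autograd function touched below (illustrative only: the tensor names and shapes are made up, and it assumes a CUDA device with the fused_mix_prec_layer_norm_cuda extension built):

    import torch
    from colossalai.nn.layer.parallel_1d._operation import FusedLayerNormAffineFunction1D

    hidden_size = 1024                      # illustrative size
    x = torch.randn(4, hidden_size, device='cuda', requires_grad=True)
    weight = torch.ones(hidden_size, device='cuda', requires_grad=True)
    bias = torch.zeros(hidden_size, device='cuda', requires_grad=True)

    # forward(ctx, input, weight, bias, normalized_shape, eps) is reached via .apply()
    out = FusedLayerNormAffineFunction1D.apply(x, weight, bias, (hidden_size,), 1e-5)
    out.sum().backward()                    # drives the custom backward()
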
---
 colossalai/nn/layer/parallel_1d/_operation.py | 45 +++++++++----------
 colossalai/nn/layer/parallel_1d/_utils.py     |  4 ++
 2 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py
index d6b851e92..ee52db237 100644
--- a/colossalai/nn/layer/parallel_1d/_operation.py
+++ b/colossalai/nn/layer/parallel_1d/_operation.py
@@ -7,7 +7,7 @@ except:
 
 
 class FusedLayerNormAffineFunction1D(torch.autograd.Function):
-  r"""
+    r"""
   Layernorm
 
   :param input: input matrix
@@ -20,27 +20,26 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
   :param eps: a value added to the denominator for numerical stability
   """
 
-  @staticmethod
-  def forward(ctx, input, weight, bias, normalized_shape, eps):
-    ctx.normalized_shape = normalized_shape
-    ctx.eps = eps
-    input_ = input.contiguous()
-    weight_ = weight.contiguous()
-    bias_ = bias.contiguous()
-    output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
-        input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
-    ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
-    return output
+    @staticmethod
+    def forward(ctx, input, weight, bias, normalized_shape, eps):
+        ctx.normalized_shape = normalized_shape
+        ctx.eps = eps
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        bias_ = bias.contiguous()
+        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
+                                                                             bias_, ctx.eps)
+        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
+        return output
 
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
+        grad_input = grad_weight = grad_bias = None
+        grad_input, grad_weight, grad_bias \
+          = fused_mix_prec_layer_norm_cuda.backward_affine(
+            grad_output.contiguous(), mean, invvar,
+            input_, ctx.normalized_shape,
+            weight_, bias_, ctx.eps)
 
-  @staticmethod
-  def backward(ctx, grad_output):
-    input_, weight_, bias_, mean, invvar = ctx.saved_tensors
-    grad_input = grad_weight = grad_bias = None
-    grad_input, grad_weight, grad_bias \
-      = fused_mix_prec_layer_norm_cuda.backward_affine(
-        grad_output.contiguous(), mean, invvar,
-        input_, ctx.normalized_shape,
-        weight_, bias_, ctx.eps)
-
-    return grad_input, grad_weight, grad_bias, None, None
\ No newline at end of file
+        return grad_input, grad_weight, grad_bias, None, None
diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py
index cc1967f11..a9cb0994d 100644
--- a/colossalai/nn/layer/parallel_1d/_utils.py
+++ b/colossalai/nn/layer/parallel_1d/_utils.py
@@ -81,6 +81,7 @@ class _ReduceGrad(torch.autograd.Function):
     :param input_: input matrix
     :param parallel_mode: parallel mode
     """
+
     @staticmethod
     def symbolic(graph, input_):
         return input_
@@ -102,6 +103,7 @@ class _ReduceInput(torch.autograd.Function):
     :param input_: input matrix
     :param parallel_mode: parallel mode
     """
+
     @staticmethod
     def symbolic(graph, input_):
         return _reduce(input_)
@@ -123,6 +125,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     :param parallel_mode: parallel mode
     :param dim: dimension
     """
+
     @staticmethod
     def symbolic(graph, input_):
         return _split(input_)
@@ -146,6 +149,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     :param parallel_mode: parallel mode
     :param dim: dimension
     """
+
     @staticmethod
     def symbolic(graph, input_):
         return _gather(input_)