Browse Source

[NFC] polish code style (#3268)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/3313/head
Yuanchen 2 years ago committed by binmakeswell
parent
commit
d58fa705b2
  1. 11
      colossalai/nn/_ops/_utils.py

11
colossalai/nn/_ops/_utils.py

@ -1,12 +1,11 @@
import torch
from typing import Union, Optional, List
from colossalai.tensor import ColoTensor
from typing import List, Optional, Union
import torch
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup, ColoTensorSpec
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
@ -135,7 +134,7 @@ class _ReduceInput(torch.autograd.Function):
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
Args:
input_: input matrix.
process_group: parallel mode.

Loading…
Cancel
Save