[NFC] polish code style (#3268)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/3313/head
Yuanchen 2 years ago committed by binmakeswell
parent c4a226b729
commit d58fa705b2

@ -1,12 +1,11 @@
import torch from typing import List, Optional, Union
from typing import Union, Optional, List
from colossalai.tensor import ColoTensor
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import divide from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup, ColoTensorSpec from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
GeneralTensor = Union[ColoTensor, torch.Tensor] GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float] Number = Union[int, float]
@ -135,7 +134,7 @@ class _ReduceInput(torch.autograd.Function):
class _SplitForwardGatherBackward(torch.autograd.Function): class _SplitForwardGatherBackward(torch.autograd.Function):
""" """
Split the input and keep only the corresponding chuck to the rank. Split the input and keep only the corresponding chuck to the rank.
Args: Args:
input_: input matrix. input_: input matrix.
process_group: parallel mode. process_group: parallel mode.

Loading…
Cancel
Save