@@ -1,12 +1,11 @@
-import torch
-from typing import Union, Optional, List
-from colossalai.tensor import ColoTensor
+from typing import List, Optional, Union
 
+import torch
 import torch.distributed as dist
 
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor import ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
 
 GeneralTensor = Union[ColoTensor, torch.Tensor]
 Number = Union[int, float]
@@ -135,7 +134,7 @@ class _ReduceInput(torch.autograd.Function):
 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
-    Split the input and keep only the corresponding chuck to the rank.
+    Split the input and keep only the corresponding chunk to the rank.
 
     Args:
         input_: input matrix.
         process_group: parallel mode.
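
For context, the `_SplitForwardGatherBackward` pattern pairs a forward-pass split with a backward-pass all-gather: each rank keeps only its own chunk of the activation, and in the backward pass the per-rank gradient chunks are gathered back to full size. The sketch below is a minimal illustration, not ColossalAI's implementation: it assumes the default torch.distributed process group is already initialized and that the split dimension is evenly divisible by the world size, and the helper names `_split` and `_gather` are invented for this example.

import torch
import torch.distributed as dist


def _split(input_: torch.Tensor, dim: int) -> torch.Tensor:
    # Keep only this rank's chunk along `dim` (assumes even divisibility).
    chunks = torch.chunk(input_, dist.get_world_size(), dim=dim)
    return chunks[dist.get_rank()].contiguous()


def _gather(input_: torch.Tensor, dim: int) -> torch.Tensor:
    # All-gather every rank's chunk and re-concatenate along `dim`.
    out = [torch.empty_like(input_) for _ in range(dist.get_world_size())]
    dist.all_gather(out, input_.contiguous())
    return torch.cat(out, dim=dim)


class SplitForwardGatherBackward(torch.autograd.Function):
    # Forward: split the input; backward: gather the gradient chunks.

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _split(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient of a split is the all-gather of the per-rank
        # gradient chunks along the same dimension; `None` corresponds
        # to the non-tensor `dim` argument of forward.
        return _gather(grad_output, ctx.dim), None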