@@ -1,12 +1,11 @@
-import torch
-from typing import Union, Optional, List
-from colossalai.tensor import ColoTensor
-
+from typing import List, Optional, Union
+
+import torch
+import torch.distributed as dist
+
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor import ProcessGroup, ColoTensorSpec
-import torch.distributed as dist
-
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
 
 GeneralTensor = Union[ColoTensor, torch.Tensor]
 Number = Union[int, float]
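Context for the aliases kept at the end of the hunk: `GeneralTensor` lets the ops in this module accept either a plain `torch.Tensor` or a `ColoTensor`. A minimal sketch of how such an alias is typically consumed, assuming the `ColoTensor.from_torch_tensor` constructor from this codebase; the helper name `to_colo_tensor` is hypothetical and not part of this diff:

def to_colo_tensor(tensor: GeneralTensor, spec: ColoTensorSpec) -> ColoTensor:
    # Pass a ColoTensor through unchanged; wrap a plain torch.Tensor otherwise.
    if isinstance(tensor, ColoTensor):
        return tensor
    return ColoTensor.from_torch_tensor(tensor, spec)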
@@ -135,7 +134,7 @@ class _ReduceInput(torch.autograd.Function):
 class _SplitForwardGatherBackward(torch.autograd.Function):
     """
-    Split the input and keep only the corresponding chuck to the rank.
+    Split the input and keep only the corresponding chunk to the rank.
 
     Args:
         input_: input matrix.
         process_group: parallel mode.
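For readers skimming the hunk above: only the docstring changes here, but the class it documents implements a standard tensor-parallel pattern, split the activation in the forward pass and all-gather the gradient in the backward pass. Below is a minimal, self-contained sketch of that pattern, not the body from this file: the helpers `_split` and `_gather`, the use of the default process group, and the assumption that the split dimension divides evenly across ranks are all illustrative choices.

import torch
import torch.distributed as dist


def _split(input_, dim=-1):
    # Keep only this rank's chunk along `dim` (assumes an even split).
    world_size = dist.get_world_size()
    if world_size == 1:
        return input_
    return torch.chunk(input_, world_size, dim=dim)[dist.get_rank()].contiguous()


def _gather(input_, dim=-1):
    # All-gather the per-rank chunks and concatenate them along `dim`.
    world_size = dist.get_world_size()
    if world_size == 1:
        return input_
    chunks = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(chunks, input_.contiguous())
    return torch.cat(chunks, dim=dim)


class SplitForwardGatherBackward(torch.autograd.Function):
    """Split the input in forward; all-gather the gradient in backward."""

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        return _split(input_, dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Splitting hands each rank one slice, so the gradient of the full
        # input is the concatenation (gather) of every rank's gradient chunk.
        return _gather(grad_output, ctx.dim), None

Call sites would invoke it as SplitForwardGatherBackward.apply(x, -1), since torch.autograd.Function subclasses are used through apply; the collectives require that dist.init_process_group has already been called.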