@@ -1,16 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import operator
+from functools import reduce
+from typing import List, Tuple, Union
 
 import torch
 import torch.distributed as dist
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
-from functools import reduce
-import operator
-from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]
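
The new `TensorShape` alias lets callers spell a shape as a `torch.Size`, a `List[int]`, or a `Tuple[int]`. A minimal sketch of how such an alias is consumed, with a hypothetical `numel` helper (not a colossalai API) showing the kind of element-count computation that the `operator`/`reduce` imports above typically serve:

```python
# Sketch only: `numel` is a hypothetical helper, not part of colossalai.
import operator
from functools import reduce
from typing import List, Tuple, Union

import torch

TensorShape = Union[torch.Size, List[int], Tuple[int]]


def numel(shape: TensorShape) -> int:
    # Flatten any accepted shape spelling into a total element count.
    return reduce(operator.mul, shape, 1)


print(numel(torch.Size([2, 3])))  # 6
print(numel([4, 5]))              # 20
print(numel((8,)))                # 8
```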
@@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
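
Per its docstring, `send_forward_recv_backward` couples the downstream send of an activation with the upstream receive of its gradient in one batched operation. Below is a standalone sketch of that pattern built on raw `torch.distributed` primitives; it illustrates the technique, not the colossalai implementation, and assumes a gloo backend launched via `torchrun --nproc_per_node=2`:

```python
# Two ranks stand in for adjacent pipeline stages.
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # env:// rendezvous from torchrun
rank = dist.get_rank()

if rank == 0:
    # Earlier stage: send the activation forward, receive the gradient back.
    output_tensor = torch.ones(2, 3)
    grad_buffer = torch.empty(2, 3)
    ops = [dist.P2POp(dist.isend, output_tensor, peer=1),
           dist.P2POp(dist.irecv, grad_buffer, peer=1)]
else:
    # Next stage: receive the activation, send a gradient back.
    input_buffer = torch.empty(2, 3)
    input_grad = torch.full((2, 3), 0.5)
    ops = [dist.P2POp(dist.irecv, input_buffer, peer=0),
           dist.P2POp(dist.isend, input_grad, peer=0)]

# Launching both ops in one batch lets the send and the receive overlap
# instead of serializing (or deadlocking) as two blocking calls could.
for req in dist.batch_isend_irecv(ops):
    req.wait()

if rank == 0:
    print("gradient received:", grad_buffer)
dist.destroy_process_group()
```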
@@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor,
                               next_rank=None,
                               dtype=torch.float,
-                              scatter_gather_tensors=False):
+                              scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
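
`send_forward_recv_forward` is the steady-state counterpart: a stage pushes its output downstream while pulling the previous stage's output upstream in the same batched call. A sketch of that pattern, again with raw `torch.distributed` primitives; the ring topology here is an illustrative stand-in for a real pipeline schedule (run with `torchrun --nproc_per_node=4`):

```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
next_rank = (rank + 1) % world_size
prev_rank = (rank - 1) % world_size

output_tensor = torch.full((4,), float(rank))  # this stage's forward output
input_buffer = torch.empty(4)                  # previous stage's output lands here

# Posting the send and the receive together avoids the deadlock that two
# blocking calls in a ring would create.
ops = [dist.P2POp(dist.isend, output_tensor, peer=next_rank),
       dist.P2POp(dist.irecv, input_buffer, peer=prev_rank)]
for req in dist.batch_isend_irecv(ops):
    req.wait()

print(f"rank {rank} got stage {prev_rank}'s output:", input_buffer)
dist.destroy_process_group()
```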