* revise shardformer readme (#4246)

* [example] add llama pretraining (#4257)

* [NFC] polish colossalai/communication/p2p.py code style

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Qianran Ma <qianranm@luchentech.com>
Branch: pull/4338/head
Authored by Michelle, committed by binmakeswell (1 year ago)
Parent: b366f1d99f
Commit: 86cf6aed5b

@@ -1,16 +1,18 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import operator
+from functools import reduce
 from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
-from functools import reduce
-import operator
-from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
+from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
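The two utilities imported above back the scatter_gather_tensors flag that appears in the p2p helpers further down. As a rough, hedged sketch of that idea (not ColossalAI's implementation, and simplified to the default process group): before a pipeline send the tensor is flattened and split into equal 1D chunks so each rank ships only its own slice, and after the receive an all-gather restores the full tensor. The helper names split_into_1d_chunks and gather_1d_chunks below are illustrative stand-ins.

import torch
import torch.distributed as dist


def split_into_1d_chunks(tensor: torch.Tensor) -> torch.Tensor:
    # Keep only this rank's equal-sized slice of the flattened tensor,
    # so a pipeline send moves roughly 1/world_size of the data.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    flat = tensor.contiguous().view(-1)
    chunk = flat.numel() // world_size  # assumes numel divides evenly
    return flat[rank * chunk:(rank + 1) * chunk].clone()


def gather_1d_chunks(chunk: torch.Tensor, full_shape: torch.Size) -> torch.Tensor:
    # The receiving side all-gathers the per-rank chunks and restores the shape.
    world_size = dist.get_world_size()
    parts = [torch.empty_like(chunk) for _ in range(world_size)]
    dist.all_gather(parts, chunk)
    return torch.cat(parts).view(full_shape)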
@@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor,
                                next_rank=None,
                                dtype=torch.float,
                                scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
-    """Batched communication operation. Sends the input tensor to the
+    """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
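For orientation, here is a minimal sketch of the pattern this docstring describes, written against plain torch.distributed rather than ColossalAI's internals: the stage's output is sent to the next pipeline stage while the gradient tensor is received from that same stage in one batched operation. The function name, the grad_shape argument, and the rank bookkeeping are assumptions for illustration only.

import torch
import torch.distributed as dist


def send_forward_recv_backward_sketch(output_tensor: torch.Tensor,
                                      grad_shape: torch.Size,
                                      next_rank: int) -> torch.Tensor:
    # Buffer for the gradient coming back from the next stage.
    grad_tensor = torch.empty(grad_shape,
                              dtype=output_tensor.dtype,
                              device=output_tensor.device)
    # One batched operation: send activations forward, receive gradients back.
    ops = [
        dist.P2POp(dist.isend, output_tensor, next_rank),
        dist.P2POp(dist.irecv, grad_tensor, next_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return grad_tensor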
@@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor,
                               next_rank=None,
                               dtype=torch.float,
                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
-    """Batched communication operation. Sends the input tensor to the
+    """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
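The companion pattern can be sketched the same way, again in plain torch.distributed and again with illustrative names rather than ColossalAI's API: the current output goes to the next stage while the previous stage's output is received as this stage's input. The prev_rank/next_rank handling and input_shape argument are assumptions.

import torch
import torch.distributed as dist


def send_forward_recv_forward_sketch(output_tensor: torch.Tensor,
                                     input_shape: torch.Size,
                                     prev_rank: int,
                                     next_rank: int) -> torch.Tensor:
    # Buffer for the activation arriving from the previous stage.
    input_tensor = torch.empty(input_shape,
                               dtype=output_tensor.dtype,
                               device=output_tensor.device)
    # Couple the forward send with the forward receive in one batch.
    ops = [
        dist.P2POp(dist.irecv, input_tensor, prev_rank),
        dist.P2POp(dist.isend, output_tensor, next_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return input_tensor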
