mirror of https://github.com/hpcaitech/ColossalAI
Fix/format (#4261)
* revise shardformer readme (#4246) * [example] add llama pretraining (#4257) * [NFC] polish colossalai/communication/p2p.py code style --------- Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Qianran Ma <qianranm@luchentech.com>pull/4338/head
parent
b366f1d99f
commit
86cf6aed5b
|
@ -1,16 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
from functools import reduce
|
||||
import operator
|
||||
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
|
||||
|
||||
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor,
|
|||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the gradient tensor from the
|
||||
next stage in pipeline as the input gradient tensor of this stage.
|
||||
|
||||
|
@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor,
|
|||
next_rank=None,
|
||||
dtype=torch.float,
|
||||
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
"""Batched communication operation. Sends the input tensor to the
|
||||
next stage in pipeline, while receives the output tensor from the
|
||||
previous stage in pipeline as the input of this stage.
|
||||
|
||||
|
|
Loading…
Reference in New Issue