#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import io
import pickle
import re
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten

from .stage_manager import PipelineStageManager


def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any:
    """Transform a tensor of pickled bytes back into a Python object.

    Any device info embedded in the byte stream is rewritten to the current device before unpickling.

    Args:
        tensor (:class:`torch.Tensor`): tensor holding the pickled bytes
        tensor_size (:class:`torch.Size`): number of valid bytes in the tensor

    Returns:
        Any: the unpickled object
    """
    buf = tensor.numpy().tobytes()[:tensor_size]
    if b"cuda" in buf:
        buf_array = bytearray(buf)
        device_index = torch.cuda.current_device()
        # There might be more than one output tensor during forward
        for cuda_str in re.finditer(b"cuda", buf_array):
            pos = cuda_str.start()
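            # 48 is ord("0"), so this rewrites the single digit that follows "cuda:" in the
            # pickled byte stream to the current device index (only indices 0-9 are handled).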
            buf_array[pos + 5] = 48 + device_index
        buf = bytes(buf_array)

    io_bytes = io.BytesIO(buf)
    byte_pickler = pickle.Unpickler(io_bytes)
    unpickle = byte_pickler.load()

    return unpickle


def check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
        # It is not expected for PG to be wrapped many times, but support it just in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg

    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
def _broadcast_object_list(
    object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
    """This is a modified version of ``broadcast_object_list`` in ``torch.distributed``.

    The only difference is that each object is moved to the correct device after being unpickled.
    If the local rank equals ``src``, the object list is broadcast from this rank; otherwise the
    object list is updated in place with the data received from rank ``src``.

    Args:
        object_list (List[Any]): list of objects to broadcast
        src (int): source rank of the broadcast
        group (ProcessGroup): process group to broadcast within
        device (:class:`torch.device`): device to use for the broadcast; the current device by default
    """

    if c10d._rank_not_in_group(group):
        c10d._warn_not_in_group("broadcast_object_list")
        return

    is_nccl_backend = _check_for_nccl_backend(group)
    current_device = None

    if device is not None:
        if is_nccl_backend and device.type != "cuda":
            raise ValueError("device type must be cuda for nccl backend")
        current_device = device
    else:
        current_device = torch.device("cpu")
        if is_nccl_backend:
            current_device = torch.device("cuda", torch.cuda.current_device())

    my_rank = dist.get_rank()
    # Serialize object_list elements to tensors on src rank.
    if my_rank == src:
        if Version(torch.__version__) >= Version("2.3.0"):
            tensor_list, size_list = zip(
                *[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
            )
        elif Version(torch.__version__) >= Version("1.13.0"):
            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
        else:
            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
        object_sizes_tensor = torch.cat(size_list)
    else:
        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)

    if is_nccl_backend:
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)

    # Concatenate and broadcast serialized object tensors
    if my_rank == src:
        object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.empty(  # type: ignore[call-overload]
            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
        )

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)

    c10d.broadcast(object_tensor, src=src, group=group, async_op=False)

    # Deserialize objects using their stored sizes.
    offset = 0

    if my_rank != src:
        for i, obj_size in enumerate(object_sizes_tensor):
            obj_view = object_tensor[offset : offset + obj_size]
            obj_view = obj_view.type(torch.uint8)
            if obj_view.device != torch.device("cpu"):
                obj_view = obj_view.cpu()
            offset += obj_size
            # unpickle
            unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)

            # fix device inconsistency: move the tensor to the current CUDA device
            if (
                isinstance(unpickle_object, torch.Tensor)
                and unpickle_object.device.index != torch.cuda.current_device()
            ):
                unpickle_object = unpickle_object.cuda()

            object_list[i] = unpickle_object
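
# Example (illustrative sketch, not executed; ``pg`` is a placeholder for an initialized
# torch.distributed process group):
#
#     objs = [{"loss": torch.tensor(1.0, device="cuda")}] if dist.get_rank() == 0 else [None]
#     _broadcast_object_list(objs, src=0, group=pg)
#     # afterwards every rank in ``pg`` holds rank 0's object in objs[0]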


def _check_for_nccl_backend(group):
    pg = group or c10d._get_default_group()
    # Gate PG wrapper check on Gloo availability.
    if c10d._GLOO_AVAILABLE:
        # It is not expected for PG to be wrapped many times, but support it just in case
        while isinstance(pg, c10d._ProcessGroupWrapper):
            pg = pg.wrapped_pg

    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


def _check_device(group):
    is_nccl_backend = _check_for_nccl_backend(group)
    current_device = torch.device("cpu")
    if is_nccl_backend:
        current_device = torch.device("cuda", torch.cuda.current_device())
    return current_device, is_nccl_backend


TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
P2PMetadata = namedtuple("P2PMetadata", ["tree_spec", "tensor_metadata", "non_tensor_obj_idx", "non_tensor_objs"])


def create_send_metadata(
    object: Any, strict: bool = True, return_tensor: bool = False
) -> Union[P2PMetadata, Tuple[P2PMetadata, List[torch.Tensor]]]:
    """
    Args:
        object (Any): object to be sent
        strict (bool, optional): whether to require that the object contains only tensors (fast send)
        return_tensor (bool, optional): whether to also return the flattened tensor objects
    """
    objs, tree_spec = tree_flatten(object)
    tensor_metadata, tensor_objs = [], []
    non_tensor_obj_idx, non_tensor_objs = [], []
    for idx, obj in enumerate(objs):
        if isinstance(obj, torch.Tensor):
            tensor_objs.append(obj)
            tensor_metadata.append(TensorMetadata(obj.shape, obj.dtype, obj.requires_grad))
        else:
            non_tensor_obj_idx.append(idx)
            non_tensor_objs.append(obj)

    assert not strict or len(non_tensor_objs) == 0, "Only support tensor for fast send"
    metadata = P2PMetadata(tree_spec, tensor_metadata, non_tensor_obj_idx, non_tensor_objs)
    return metadata if not return_tensor else (metadata, tensor_objs)
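
# Example (illustrative sketch, not executed; the input dict is a hypothetical payload):
#
#     metadata, tensors = create_send_metadata({"hidden": torch.randn(2, 4), "step": 3},
#                                               strict=False, return_tensor=True)
#     # tensors                      -> [the (2, 4) tensor]
#     # metadata.tensor_metadata     -> [TensorMetadata(torch.Size([2, 4]), torch.float32, False)]
#     # metadata.non_tensor_obj_idx  -> [1]   (position of the int in the flattened pytree)
#     # metadata.non_tensor_objs     -> [3]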


def _filling_ops_queue(
    obj: Union[torch.Tensor, List[torch.Tensor]],
    comm_op: Callable,
    comm_rank: int,
    ops_queue: List,
    group: ProcessGroup,
):
    if isinstance(obj, torch.Tensor):
        obj = obj.contiguous()
        op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
        ops_queue.append(op_to_add)
    else:
        for tensor_to_comm in obj:
            assert isinstance(tensor_to_comm, torch.Tensor)
            _filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group)
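
# Example (illustrative sketch, not executed; ``send_buf``, ``recv_buf``, ``peer`` and ``pg`` are
# placeholder names for preallocated tensors and an initialized process group):
#
#     ops = []
#     _filling_ops_queue(send_buf, dist.isend, peer, ops, pg)
#     _filling_ops_queue(recv_buf, dist.irecv, peer, ops, pg)
#     for req in dist.batch_isend_irecv(ops):
#         req.wait()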


def _create_recv_buffer(tensor_metadata: List[TensorMetadata], current_device) -> List[torch.Tensor]:
    buffer_recv = []
    for metadata in tensor_metadata:
        tensor_recv = torch.empty(
            metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype
        )
        buffer_recv.append(tensor_recv)
    return buffer_recv


def _batch_send_recv_tensor(
    send_tensor_list: Optional[List[torch.Tensor]],
    recv_tensor_metadata: Optional[List[TensorMetadata]],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    overlap_p2p: bool = True,
    send_first: bool = True,
) -> Tuple[Optional[List[torch.Tensor]], List]:
    buffer_recv = None
    if recv_tensor_metadata is not None:
        buffer_recv = _create_recv_buffer(recv_tensor_metadata, current_device)

    ops = []
    is_send = send_dst is not None and send_tensor_list is not None
    is_recv = recv_src is not None and buffer_recv is not None

    if send_first:
        if is_send:
            assert send_group is not None
            _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
        if is_recv:
            assert recv_group is not None
            _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
    else:
        if is_recv:
            assert recv_group is not None
            _filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
        if is_send:
            assert send_group is not None
            _filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        if not overlap_p2p:
            for req in reqs:
                req.wait()
            return buffer_recv, []
        else:
            return buffer_recv, reqs
    return None, []
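
# NOTE: ``send_first`` only controls the order in which the isend/irecv ops are queued before the
# single ``dist.batch_isend_irecv`` call above; paired stages that both send and receive are
# presumably expected to pass complementary values so that their queued orders match up.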


def _send_recv_serialization_object(
    object: Optional[P2PMetadata],
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup],
    recv_group: Optional[ProcessGroup],
    current_device: Any,
    is_nccl_backend: bool,
    send_first: bool = True,
) -> Optional[P2PMetadata]:
    ops = []
    send_object_tensor = None
    send_object_size_tensor = None
    if object is not None and send_dst is not None:
        if Version(torch.__version__) >= Version("2.3.0"):
            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
                object, device=current_device, group=send_group
            )
        elif Version(torch.__version__) >= Version("1.13.0"):
            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
        else:
            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)

        if is_nccl_backend:
            send_object_size_tensor = send_object_size_tensor.to(current_device)
            send_object_tensor = send_object_tensor.to(current_device)

    recv_object_size_tensor = None
    if recv_src is not None:
        recv_object_size_tensor = torch.empty(1, dtype=torch.long)
        if is_nccl_backend:
            recv_object_size_tensor = recv_object_size_tensor.to(current_device)

    if send_first:
        if send_object_size_tensor is not None:
            _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
        if recv_src is not None:
            _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
    else:
        if recv_src is not None:
            _filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
        if send_object_size_tensor is not None:
            _filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()  # This blocks the compute stream in torch

    ops = []
    is_send = send_dst is not None and send_object_tensor is not None
    is_recv = recv_src is not None and recv_object_size_tensor is not None

    recv_object_tensor = None
    if is_recv:
        recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
        if is_nccl_backend:
            recv_object_tensor = recv_object_tensor.to(current_device)

    if send_first:
        if is_send:
            _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
        if is_recv:
            _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
    else:
        if is_recv:
            _filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
        if is_send:
            _filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    if recv_object_tensor is not None and recv_object_size_tensor is not None:
        recv_object_tensor = recv_object_tensor.type(torch.uint8)
        if recv_object_tensor.device != torch.device("cpu"):
            recv_object_tensor = recv_object_tensor.cpu()

        unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

        if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
            unpickle_object = unpickle_object.cuda()

        return unpickle_object


def _communicate(
    object: Any,
    send_dst: Optional[int],
    recv_src: Optional[int],
    overlap_p2p: bool,
    send_group: Optional[ProcessGroup] = None,
    recv_group: Optional[ProcessGroup] = None,
    send_metadata: bool = True,
    metadata_recv: Optional[P2PMetadata] = None,
    send_first: Optional[bool] = None,
) -> Any:
    """
    Send an object to ``send_dst`` and receive an object from ``recv_src``.

    Args:
        object (Any): object to be sent
        send_dst (int): rank of the destination
        recv_src (int): rank of the source
        overlap_p2p (bool): whether to overlap p2p communication with computation
        send_group (ProcessGroup, optional): process group of the sender
        recv_group (ProcessGroup, optional): process group of the receiver
        send_metadata (bool, optional): whether to send metadata
        metadata_recv (P2PMetadata, optional): cached metadata of the object to be received
        send_first (bool, optional): whether to queue the send before the receive
    """
    assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None"
    assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None"
    assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None"
    assert (
        metadata_recv is None or len(metadata_recv.non_tensor_obj_idx) == 0
    ), "metadata_recv should not contain non-tensor objects"

    metadata_send, tensor_objs = None, None
    if object is not None:
        # NOTE: if object contains non-tensor objects, we have to send metadata
        metadata_send, tensor_objs = create_send_metadata(object, strict=False, return_tensor=True)
        send_metadata = send_metadata or len(metadata_send.non_tensor_obj_idx) > 0
    else:
        send_metadata = False

    assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group)
    current_send_device, is_send_nccl_backend = _check_device(send_group)
    current_recv_device, is_recv_nccl_backend = _check_device(recv_group)

    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

    assert current_send_device == current_recv_device
    current_device = current_send_device

    if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None):
        # Send and receive metadata
        _metadata_recv = _send_recv_serialization_object(
            object=metadata_send,
            send_dst=send_dst if send_metadata else None,
            recv_src=recv_src if metadata_recv is None else None,
            send_group=send_group if send_metadata else None,
            recv_group=recv_group if metadata_recv is None else None,
            current_device=current_device,
            is_nccl_backend=is_nccl_backend,
            send_first=send_first if send_first is not None else True,
        )
        assert (
            metadata_recv is None or _metadata_recv is None
        ), "You shouldn't receive metadata when using the cached metadata"
        metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv

    # Send and receive data
    recv_tensor_metadata = None if metadata_recv is None else metadata_recv.tensor_metadata
    recv_tensor_objs, wait_handles = _batch_send_recv_tensor(
        tensor_objs,
        recv_tensor_metadata,
        send_dst,
        recv_src,
        send_group,
        recv_group,
        current_device,
        overlap_p2p=overlap_p2p,
        send_first=send_first if send_first is not None else True,
    )

    if metadata_recv is not None:
        assert isinstance(metadata_recv, P2PMetadata)
        tree_spec = metadata_recv.tree_spec
        non_tensor_obj_idx = metadata_recv.non_tensor_obj_idx
        non_tensor_objs = metadata_recv.non_tensor_objs

        if recv_tensor_objs is None:
            recv_tensor_objs = []

        for idx in non_tensor_obj_idx:
            recv_tensor_objs.insert(idx, non_tensor_objs.pop(0))
        recv_object = tree_unflatten(recv_tensor_objs, tree_spec)
        return recv_object, wait_handles

    return None, wait_handles
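
# NOTE: metadata caching — when the layout of the communicated object is stable across iterations,
# callers may reuse the P2PMetadata obtained in an earlier step and pass it back via ``metadata_recv``
# (setting ``send_metadata=False`` on the sending side); ``_communicate`` then skips the metadata
# exchange and only transfers the tensor data.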


def _p2p_comm(
    tensor_send_next: torch.Tensor,
    recv_prev: bool,
    peer: int,
    group: ProcessGroup,
    comm_dtype: torch.dtype = torch.float16,
):
    """
    Send and receive a tensor using P2P communication; used when the pipeline size is 2 to avoid a communication race.

    Args:
        tensor_send_next (torch.Tensor): tensor to be sent to the next stage
        recv_prev (bool): whether to receive a tensor from the previous stage
        peer (int): rank of the peer
        group (ProcessGroup): process group
        comm_dtype (torch.dtype): dtype of the tensor to be sent

    Returns:
        torch.Tensor: tensor received from the previous stage
    """
    # send and recv shape
    send_next_shape = None
    recv_prev_shape = None

    if tensor_send_next is not None:
        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
    if recv_prev:
        recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)

    ops = []
    if send_next_shape is not None:
        send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
        ops.append(send_next_op)
    if recv_prev_shape is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            recv_prev_shape,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()

    if recv_prev_shape is not None:
        recv_prev_shape = recv_prev_shape.tolist()

    # send and recv data
    tensor_recv_prev = None
    if recv_prev:
        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)

    ops = []
    if tensor_send_next is not None:
        send_next_op = dist.P2POp(
            dist.isend,
            tensor_send_next,
            peer=peer,
            group=group,
        )
        ops.append(send_next_op)
    if tensor_recv_prev is not None:
        recv_prev_op = dist.P2POp(
            dist.irecv,
            tensor_recv_prev,
            peer=peer,
            group=group,
        )
        ops.append(recv_prev_op)
    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    return tensor_recv_prev


class PipelineP2PCommunication:
    def __init__(self, stage_manager: PipelineStageManager, overlap_p2p: bool = True) -> None:
        self.stage_manager = stage_manager
        self.overlap_p2p = overlap_p2p

    def recv_forward(
        self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
    ) -> Tuple[Any, List]:
        """Copy the forward output from the previous stage in the pipeline as the input tensor of this stage.

        Args:
            prev_rank (int, optional): The rank of the source of the tensor.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        input_tensor, wait_handles = _communicate(
            object=None,
            recv_src=prev_rank,
            send_dst=None,
            recv_group=self.stage_manager.get_p2p_process_group(),
            metadata_recv=metadata_recv,
            overlap_p2p=self.overlap_p2p,
        )

        return input_tensor, wait_handles

    def recv_backward(
        self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None
    ) -> Tuple[Any, List]:
        """Copy the gradient tensor from the next stage in the pipeline as the input gradient of this stage.

        Args:
            next_rank (int, optional): The rank of the source of the tensor.

        Returns:
            Any: The input gradient tensor or gradient tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()

        output_tensor_grad, wait_handles = _communicate(
            object=None,
            recv_src=next_rank,
            send_dst=None,
            recv_group=self.stage_manager.get_p2p_process_group(),
            metadata_recv=metadata_recv,
            overlap_p2p=self.overlap_p2p,
        )

        return output_tensor_grad, wait_handles

    def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> List:
        """Sends the forward output tensor to the next stage in the pipeline.

        Args:
            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.

        Returns:
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        _, handles = _communicate(
            output_object,
            recv_src=None,
            send_dst=next_rank,
            send_group=self.stage_manager.get_p2p_process_group(),
            send_metadata=send_metadata,
            overlap_p2p=self.overlap_p2p,
        )
        return handles

    def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> List:
        """Sends the gradient tensor to the previous stage in the pipeline.

        Args:
            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor.

        Returns:
            List: List of handles for the communication requests, if overlap is enabled.
        """
        if prev_rank is None:
            prev_rank = self.stage_manager.get_prev_rank()
        _, handles = _communicate(
            input_object,
            recv_src=None,
            send_dst=prev_rank,
            send_group=self.stage_manager.get_p2p_process_group(),
            send_metadata=send_metadata,
            overlap_p2p=self.overlap_p2p,
        )
        return handles

    def send_forward_recv_forward(
        self,
        output_object: Any,
        is_send: bool,
        is_recv: bool,
        send_first: bool,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
    ) -> Tuple[Any, List]:
        """Sends the forward output to the next pipeline stage and copies the forward input from the previous pipeline stage.

        Args:
            output_object (Any): Object to be sent.
            is_send (bool): Whether to send the output tensor to the next pipeline stage.
            is_recv (bool): Whether to copy the input tensor from the previous pipeline stage.
            send_first (bool): Whether to send before receive.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): The cached metadata (size, type) of the object to be received.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        next_rank = self.stage_manager.get_next_rank() if is_send else None
        prev_rank = self.stage_manager.get_prev_rank() if is_recv else None
        group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            output_object,
            send_dst=next_rank,
            recv_src=prev_rank,
            send_group=group if is_send else None,
            recv_group=group if is_recv else None,
            send_metadata=send_metadata if is_send else False,
            metadata_recv=metadata_recv if is_recv else None,
            send_first=send_first,
            overlap_p2p=self.overlap_p2p,
        )

    def send_backward_recv_backward(
        self,
        input_object: Any,
        is_send: bool,
        is_recv: bool,
        send_first: bool,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
    ) -> Tuple[Any, List]:
        """Sends the gradient tensor to the previous pipeline stage and copies the gradient tensor from the next pipeline stage.

        Args:
            input_object (Any): Object to be sent.
            is_send (bool): Whether to send the gradient tensor to the previous pipeline stage.
            is_recv (bool): Whether to copy the gradient tensor from the next pipeline stage.
            send_first (bool): Whether to send before receive.
            send_metadata (bool, optional): Whether to send metadata.
            metadata_recv (P2PMetadata, optional): The cached metadata (size, type) of the object to be received.

        Returns:
            Any: The input gradient tensor or gradient tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        prev_rank = self.stage_manager.get_prev_rank() if is_send else None
        next_rank = self.stage_manager.get_next_rank() if is_recv else None

        group = self.stage_manager.get_p2p_process_group()

        return _communicate(
            input_object,
            send_dst=prev_rank,
            recv_src=next_rank,
            send_group=group if is_send else None,
            recv_group=group if is_recv else None,
            send_metadata=send_metadata if is_send else False,
            metadata_recv=metadata_recv if is_recv else None,
            send_first=send_first,
            overlap_p2p=self.overlap_p2p,
        )

    def send_forward_recv_backward(
        self,
        input_object: Any,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_first: Optional[bool] = None,
    ) -> Tuple[Any, List]:
        """Sends the forward output to the next pipeline stage and copies the gradient tensor from it.

        Args:
            input_object (Any): Object to be sent.

        Returns:
            Any: The gradient tensor or gradient tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        next_rank = self.stage_manager.get_next_rank()
        group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            input_object,
            next_rank,
            next_rank,
            send_group=group,
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_first=send_first,
            overlap_p2p=False,
        )

    def send_backward_recv_forward(
        self,
        input_object: Any,
        send_metadata: bool = True,
        metadata_recv: Optional[P2PMetadata] = None,
        send_first: Optional[bool] = None,
    ) -> Tuple[Any, List]:
        """Sends the gradient tensor to the previous pipeline stage and copies the forward output from it.

        Args:
            input_object (Any): Object to be sent.

        Returns:
            Any: The input tensor or input tensor list.
            List: List of handles for the communication requests, if overlap is enabled.
        """
        prev_rank = self.stage_manager.get_prev_rank()
        group = self.stage_manager.get_p2p_process_group()
        return _communicate(
            input_object,
            prev_rank,
            prev_rank,
            send_group=group,
            recv_group=group,
            send_metadata=send_metadata,
            metadata_recv=metadata_recv,
            send_first=send_first,
            overlap_p2p=False,
        )

    def p2p_communicate(
        self,
        output_object: Any,
        recv_pre: bool,
        next_rank: Optional[int] = None,
        comm_dtype: torch.dtype = torch.float16,
    ) -> Any:
        """
        Sends the output tensor to the next stage in the pipeline using ``P2POp`` from torch.

        Args:
            output_object (Any): Object to be sent.
            recv_pre (bool): Whether to also receive a tensor from the previous stage.
            next_rank (int, optional): The rank of the recipient of the tensor.
            comm_dtype (torch.dtype): dtype of the tensor to be sent.
        """
        if next_rank is None:
            next_rank = self.stage_manager.get_next_rank()
        recv_tensor = _p2p_comm(
            output_object,
            recv_pre,
            next_rank,
            self.stage_manager.get_p2p_process_group(),
            comm_dtype,
        )
        return recv_tensor
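
# Example usage (illustrative sketch, not executed; assumes torch.distributed is initialized and
# ``stage_manager`` is a constructed PipelineStageManager with ``is_first_stage``/``is_last_stage`` helpers):
#
#     comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)
#     if not stage_manager.is_first_stage():
#         inputs, _ = comm.recv_forward()
#     outputs = ...  # run this stage's forward pass
#     if not stage_manager.is_last_stage():
#         comm.send_forward(outputs)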