|
|
|
import gc
|
|
|
|
import itertools
|
|
|
|
from functools import reduce
|
|
|
|
from operator import mul
|
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from torch.distributed.distributed_c10d import GroupMember
|
|
|
|
|
|
|
|
|
|
|
|
def prod(nums: List[int]) -> int:
    """Product of a sequence of numbers.

    Args:
        nums (List[int]): The numbers to multiply. Any iterable of numbers is accepted.

    Returns:
        int: The product of the numbers. An empty input yields ``1`` (the multiplicative
        identity) instead of raising ``TypeError``.
    """
    # The initializer makes the empty case well-defined; for non-empty input the result is unchanged.
    return reduce(mul, nums, 1)
|
|
|
|
|
|
|
|
|
|
|
|
class ProcessGroupMesh:
|
|
|
|
"""A helper class to manage the process group mesh. It only describes how to organize process groups and is decoupled from any parallel method.
|
|
|
|
It only initializes process groups and caches them. The parallel method is responsible for managing these groups and using them to perform the parallel computation.
|
|
|
|
|
|
|
|
We use an N-D tuple to represent the shape of the process group mesh, and an N-D coordinate to represent each process.
|
|
|
|
For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
*size (int): The size of each dimension of the process group mesh. The product of all sizes must equal the world size.
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
shape (Tuple[int, ...]): The shape of the process group mesh.
|
|
|
|
rank (int): The rank of the current process.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, *size: int) -> None:
    """Describe a process group mesh of the given shape for the current process.

    Args:
        *size (int): Extent of each mesh dimension; their product must match the world size.

    Raises:
        AssertionError: If torch.distributed has not been initialized, or if the product of
            ``size`` differs from the world size.
    """
    assert dist.is_initialized(), "Please initialize torch.distributed first."
    world_size, prod_size = dist.get_world_size(), prod(size)
    assert prod_size == world_size, (
        f"The product of the size({prod_size}) must be equal to the world size({world_size})."
    )

    # Mesh geometry and this process's position inside it.
    mesh_shape = size
    my_rank = dist.get_rank()
    self._shape, self._rank = mesh_shape, my_rank
    self._coord = ProcessGroupMesh.unravel(my_rank, mesh_shape)

    # Two-way cache: sorted rank tuple -> group (possibly NON_GROUP_MEMBER), and group -> ranks.
    self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {}
    self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
|
|
|
|
|
|
|
|
def destroy_mesh_process_groups(self):
    r"""
    Destroy every process group created through this mesh and drop the cached references.

    This method must be called explicitly; nothing in this class invokes it automatically.

    Note:
        Process groups are tracked globally by ``torch.distributed`` and may outlive this
        object unless destroyed explicitly. Each cached group is destroyed, and then both
        internal caches are cleared. This has two effects:

        * later lookups will create fresh groups instead of returning destroyed ones;
        * the destroyed group objects are no longer referenced here, so ``gc.collect()``
          can actually reclaim them.

        Cached ``GroupMember.NON_GROUP_MEMBER`` entries (which ``dist.new_group`` may return
        and ``_get_group`` stores) are skipped, since they are not real local groups.
    """
    for group in self._ranks_to_group.values():
        if group is GroupMember.NON_GROUP_MEMBER:
            continue
        dist.destroy_process_group(group)

    # Drop stale references so destroyed groups are neither reused nor kept alive.
    self._ranks_to_group.clear()
    self._group_to_ranks.clear()

    # Manually collect garbage to release the memory held by the destroyed groups.
    gc.collect()
|
|
|
|
|
|
|
|
@property
def shape(self) -> Tuple[int, ...]:
    """Tuple[int, ...]: Extent of every dimension of the mesh, as passed to the constructor."""
    mesh_shape = self._shape
    return mesh_shape
|
|
|
|
|
|
|
|
@property
def rank(self) -> int:
    """int: Global rank of the current process, as recorded at construction time."""
    current_rank = self._rank
    return current_rank
|
|
|
|
|
|
|
|
def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
    """Return the extent of one mesh dimension, or the whole mesh shape.

    Args:
        dim (Optional[int], optional): Dimension to query; ``None`` selects every dimension.
            Defaults to None.

    Returns:
        Union[int, Tuple[int, ...]]: ``self._shape[dim]`` when ``dim`` is given, otherwise the
        full shape tuple.
    """
    return self._shape if dim is None else self._shape[dim]
|
|
|
|
|
|
|
|
def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
    """Return the current process's position along one mesh dimension, or its full coordinate.

    Args:
        dim (Optional[int], optional): Dimension to query; ``None`` selects every dimension.
            Defaults to None.

    Returns:
        Union[int, Tuple[int, ...]]: ``self._coord[dim]`` when ``dim`` is given, otherwise the
        full coordinate tuple.
    """
    return self._coord if dim is None else self._coord[dim]
|
|
|
|
|
|
|
|
@staticmethod
def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Convert a flat rank to its mesh coordinate (C / row-major order).

    Args:
        rank (int): Rank to be converted.
        shape (Tuple[int, ...]): Shape of the process group mesh.

    Returns:
        Tuple[int, ...]: Coordinate of the rank, as plain Python ``int`` values.
    """
    # np.unravel_index yields numpy integer scalars; convert them so the result matches the
    # annotated Tuple[int, ...] (values compare and hash equal to the numpy originals).
    return tuple(int(index) for index in np.unravel_index(rank, shape))
|
|
|
|
|
|
|
|
@staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int:
    """Convert a mesh coordinate into the corresponding flat rank.

    ``mode`` is forwarded to ``numpy.ravel_multi_index`` and must be one of ``"raise"``,
    ``"wrap"`` or ``"clip"``. With ``"wrap"``, out-of-range indices wrap around, for example
    ``ravel((0, i, 0), (1, 2, 1), "wrap") == i % 2``.

    Args:
        coord (Tuple[int, ...]): Coordinate to be converted.
        shape (Tuple[int, ...]): Shape of the process group mesh.
        mode (str, optional): Out-of-range handling for ``numpy.ravel_multi_index``.
            Defaults to ``"raise"``.

    Returns:
        int: Rank of the coordinate.
    """
    assert mode in ("raise", "wrap", "clip")
    flat_index = np.ravel_multi_index(coord, shape, mode)
    return int(flat_index)
|
|
|
|
|
|
|
|
def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
    """Return the cached process group for ``ranks_in_group``, creating it on first use.

    The ranks are sorted before lookup, so the order in which they are given does not matter.
    ``dist.new_group`` may return ``GroupMember.NON_GROUP_MEMBER``. Such a result is cached
    like any other, but it is not recorded in the group -> ranks map.

    Args:
        ranks_in_group (List[int]): Ranks in the process group (any iterable of ints).
        backend (Optional[str], optional): Backend of the process group, used only on creation.
            Defaults to None.

    Returns:
        ProcessGroup: The process group with the given ranks, or whatever ``dist.new_group``
        returned for them.
    """
    key = tuple(sorted(ranks_in_group))
    if key in self._ranks_to_group:
        return self._ranks_to_group[key]

    new_pg = dist.new_group(list(key), backend=backend)
    self._ranks_to_group[key] = new_pg
    if new_pg is not GroupMember.NON_GROUP_MEMBER:
        self._group_to_ranks[new_pg] = key
    return new_pg
|
|
|
|
|
|
|
|
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
    """Look up the (sorted) ranks of a process group that was created through this mesh.

    Args:
        group (ProcessGroup): A group previously returned by this mesh.

    Returns:
        List[int]: A fresh list holding the group's ranks.

    Raises:
        KeyError: If ``group`` was not created (as a member group) by this mesh.
    """
    member_ranks = self._group_to_ranks[group]
    return [*member_ranks]
|
|
|
|
|
|
|
|
@staticmethod
def get_coords_along_axis(
    base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
) -> List[Tuple[int, ...]]:
    """Get every coordinate obtained by substituting the given indices on the given axes.

    Starting from ``base_coord``, each axis in ``axis`` is replaced, in order, by each of its
    indices. The result is the Cartesian product, with the first axis varying slowest.

    Args:
        base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axes are based on.
        axis (int or List[int]): Axis (or axes) along which the coordinates are generated.
            Negative values count from the last dimension.
        indices_at_axis (List[int] or List[List[int]]): Indices to place on the axis. A flat list
            when ``axis`` is an int, one list per axis when ``axis`` is a list.

    Returns:
        List[Tuple[int, ...]]: Coordinates along the axes.

    Raises:
        IndexError: If an axis is out of range for ``base_coord``.
    """
    if isinstance(axis, int):
        axis = [
            axis,
        ]
        assert isinstance(indices_at_axis[0], int), f"Expected int, but got {type(indices_at_axis[0])}."
        indices_at_axis = [
            indices_at_axis,
        ]

    ndim = len(base_coord)
    coords_in_group = [base_coord]
    for ax, indices_at_ax in zip(axis, indices_at_axis):
        if not -ndim <= ax < ndim:
            raise IndexError(f"Axis {ax} is out of range for a {ndim}-D coordinate.")
        # Normalize negative axes: with ax == -1, `coord[ax + 1 :]` would be the whole tuple.
        ax %= ndim
        coords_in_group = [
            coord[:ax] + (idx,) + coord[ax + 1 :] for coord in coords_in_group for idx in indices_at_ax
        ]
    return coords_in_group
|
|
|
|
|
|
|
|
def create_group_along_axis(
    self,
    axis: Union[int, List[int]],
    indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
    backend: Optional[str] = None,
) -> ProcessGroup:
    """Create every process group along ``axis`` and return the one containing the current rank.

    The mesh is collapsed to extent 1 on each axis in ``axis``. For every remaining base
    coordinate, one group is formed from the coordinates obtained by placing
    ``indices_at_axis`` on those axes. All such groups are created on this rank (presumably
    because ``dist.new_group`` must be entered by every process), but only the one containing
    the current rank is returned.

    Args:
        axis (int or List[int]): Axis (or axes) along which the process groups are created.
        indices_at_axis (List[int] or List[List[int]], optional): Indices to use on each axis.
            A flat list when ``axis`` is an int, one list per axis when ``axis`` is a list.
            Defaults to every index of each axis.
        backend (Optional[str], optional): Backend of the process groups. Defaults to None.

    Returns:
        ProcessGroup: The group containing the current rank, or ``None`` if no created group
        contains it (e.g. its index on ``axis`` is not among ``indices_at_axis``).
    """
    if isinstance(axis, int):
        axis = [axis]
        if indices_at_axis is not None:
            assert isinstance(indices_at_axis[0], int)
            indices_at_axis = [indices_at_axis]

    if not indices_at_axis:
        indices_at_axis = [list(range(self._shape[ax])) for ax in axis]

    # The chosen axes collapse to a single choice each; their values come from `indices_at_axis`.
    base_shape = list(self._shape)
    for ax in axis:
        base_shape[ax] = 1

    own_group = None
    # Cartesian product over the collapsed shape enumerates one base coordinate per group.
    for base_coord in itertools.product(*map(range, base_shape)):
        member_coords = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
        member_ranks = tuple(ProcessGroupMesh.ravel(coord, self._shape) for coord in member_coords)
        group = self._get_group(member_ranks, backend=backend)
        if self._rank in member_ranks:
            own_group = group
    return own_group
|
|
|
|
|
|
|
|
def get_group_along_axis(
|
|
|
|
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
|
|
|
) -> ProcessGroup:
|
|
|
|
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
|
|
|
|
|
|
|
Args:
|
[MoE/ZeRO] Moe refactor with zero refactor (#5821)
* [moe] removed openmoe-coupled code and rectify mixstral code (#5471)
* [Feauture] MoE refractor; Intergration with Mixtral (#5682)
* cherry pick from refractor-moe branch
* tests passed
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support ep + zero
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add mixtral auto policy & move pipeline forward code to modeling folder
* [moe refactor] modify kernel test without Route Class
* [moe refactor] add moe tensor test path environment variable to github workflow
* fix typos
* fix moe test bug due to the code rebase
* [moe refactor] fix moe zero test, and little bug in low level zero
* fix typo
* add moe tensor path to github workflow
* remove some useless code
* fix typo & unify global variable XX_AXIS logic without using -1
* fix typo & prettifier the code
* remove print code & support zero 2 test
* remove useless code
* reanme function
* fix typo
* fix typo
* Further improve the test code
* remove print code
* [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test
* [moe refactor] skip some unit test which will be refactored later
* [moe refactor] fix unit import error
* [moe refactor] fix circular import issues
* [moe refactor] remove debug code
* [moe refactor] update github workflow
* [moe/zero] refactor low level optimizer (#5767)
* [zero] refactor low level optimizer
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] MoE refactor with newest version of ZeRO (#5801)
* [zero] remove redundant members in BucketStore (#5802)
* [zero] align api with previous version
* [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819)
* [moe refactor] update unit test with the refactored ZeRO and remove useless test
* move moe checkpoint to checkpoint folder and exchange global axis to class member
* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug
* fix zero unit test
* Add an assertion to prevent users from using it incorrectly
* [hotfix]Solve the compatibility issue of zero refactor (#5823)
* [moe refactor] update unit test with the refactored ZeRO and remove useless test
* move moe checkpoint to checkpoint folder and exchange global axis to class member
* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug
* fix zero unit test
* Add an assertion to prevent users from using it incorrectly
* Modify function parameter names to resolve compatibility issues
* [zero] fix missing hook removal (#5824)
* [MoE] Resolve .github conflict (#5829)
* [Fix/Example] Fix Llama Inference Loading Data Type (#5763)
* [fix/example] fix llama inference loading dtype
* revise loading dtype of benchmark llama3
* [release] update version (#5752)
* [release] update version
* [devops] update compatibility test
* [devops] update compatibility test
* [devops] update compatibility test
* [devops] update compatibility test
* [test] fix ddp plugin test
* [test] fix gptj and rpc test
* [devops] fix cuda ext compatibility
* [inference] fix flash decoding test
* [inference] fix flash decoding test
* fix (#5765)
* [test] Fix/fix testcase (#5770)
* [fix] branch for fix testcase;
* [fix] fix test_analyzer & test_auto_parallel;
* [fix] remove local change about moe;
* [fix] rm local change moe;
* [Hotfix] Add missing init file in inference.executor (#5774)
* [CI/tests] simplify some test case to reduce testing time (#5755)
* [ci/tests] simplify some test case to reduce testing time
* [ci/tests] continue to remove test case to reduce ci time cost
* restore some test config
* [ci/tests] continue to reduce ci time cost
* [misc] update dockerfile (#5776)
* [misc] update dockerfile
* [misc] update dockerfile
* [devops] fix docker ci (#5780)
* [Inference]Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to the opinions of review
* add Citation
* change _block_tables tolist
* [hotfix] fix llama flash attention forward (#5777)
* [misc] Accelerate CI for zero and dist optim (#5758)
* remove fp16 from lamb
* remove d2h copy in checking states
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Test/CI] remove test cases to reduce CI duration (#5753)
* [test] smaller gpt2 test case
* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py
* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py
* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py
* Revert "[test] smaller gpt2 test case"
Some tests might depend on the size of model (num of chunks)
This reverts commit df705a5210b8901645992adf276e320e48766ebf.
* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py
* [CI] smaller test model for two mwo the two modifid cases
* [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there
* [hotfix] fix testcase in test_fx/test_tracer (#5779)
* [fix] branch for fix testcase;
* [fix] fix test_analyzer & test_auto_parallel;
* [fix] remove local change about moe;
* [fix] rm local change moe;
* [fix] fix test_deepfm_model & test_dlrf_model;
* [fix] fix test_hf_albert & test_hf_gpt;
* [gemini] optimize reduce scatter d2h copy (#5760)
* [gemini] optimize reduce scatter d2h copy
* [fix] fix missing reduce variable
* [refactor] remove legacy async reduce scatter code
* [gemini] missing sync
* Revert "[refactor] remove legacy async reduce scatter code"
This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979.
* [gemini] further optimize with async all reduce
* [fix] pass flag from manager to chunk
* Allow building cuda extension without a device. (#5535)
Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are.
* [misc] fix dist logger (#5782)
* [install]fix setup (#5786)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update requirements (#5787)
* [shardformer] fix import (#5788)
* upgrade colossal-chat support tp_group>1, add sp for sft
* upgrade ppo dpo rm script
* run pre-commit
* moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy
* fix training script
* fix ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix transformers version
* remove duplicated test
* fix datasets version
* remove models that require huggingface auth from ci
* remove local data path
* update ci
* remove baichuan from template test due to transformer version conflict
* merge
* Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Fix tests and naming
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Clean up
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* replace the customized dataloader setup with the build-in one
* replace the customized dataloader setup with the build-in one
* Remove flash attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* fix readme
* Fix test import
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* update sft trainning script
* [Inference]refactor baichuan (#5791)
* refactor baichuan
* remove unused code and add TODO for lazyinit
* [test] fix chatglm test kit (#5793)
* [shardformer] fix modeling of bloom and falcon (#5796)
* [test] fix qwen2 pytest distLarge (#5797)
* [Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Fix flash-attn import
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Add generalized model test
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Remove exposed path to model
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Add default value for use_flash_attn
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Rename model test
Signed-off-by: char-1ee <xingjianli59@gmail.com>
---------
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* [Gemini] Use async stream to prefetch and h2d data moving (#5781)
* use async stream to prefetch and h2d data moving
* Remove redundant code
* [gemini] quick fix on possible async operation (#5803)
* [gemini] quick fix on possible async operation
* [gemini] quick fix on possible async operation
* [shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)
* [shardformer] fix modeling of gpt2 and gptj
* [shardformer] fix whisper modeling
* [misc] update requirements
---------
Co-authored-by: ver217 <lhx0217@gmail.com>
* [shardformer]upgrade transformers for mistral (#5808)
* upgrade transformers for mistral
* fix
* fix
* [shardformer]upgrade transformers for llama (#5809)
* update transformers
fix
* fix
* fix
* [inference] upgrade transformers (#5810)
* update transformers
fix
* fix
* fix
* fix
* fix
* [gemini] update transformers for gemini (#5814)
---------
Co-authored-by: ver217 <lhx0217@gmail.com>
* Support 4d parallel + flash attention (#5789)
* support tp + sp + pp
* remove comments
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
---------
Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
* [zero] fix hook bug
* [zero] add low level optimizer back (#5839)
* [zero] fix param & refactor
* [zero] add back original low level opt
* [zero] remove moe related
* [zero] pass zero tests
* [zero] refactor
* [chore] add del func back
* [zero] comments and naming (#5840)
* [zero] modify api (#5843)
* [zero] modify api
* [test] remove _grad_store access in tests
* [test] fix (#5857)
* [CI] skip openmoe CI check
* [CI] fox pre-commit
* [zero] remove redundant memebr init (#5862)
* [misc] remove useless code, modify the pg mesh implementation
* [misc] remove useless code, modify the pg mesh implementation
* [misc] use tempfile
* resolve conflict with main branch
* [misc] use tempfile in test_moe_checkpoint.py
* [misc] remove useless code, add assertion about sequence parallel, move logger into function
* [misc] remove useless code
---------
Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
5 months ago
|
|
|
axis (int or list of int): Axes along which the process groups are created.
|
|
|
|
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
|
|
|
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
ProcessGroup: The process group along the given axis which the current process belongs to.
|
|
|
|
"""
|
[MoE/ZeRO] Moe refactor with zero refactor (#5821)
* [moe] removed openmoe-coupled code and rectify mixstral code (#5471)
* [Feauture] MoE refractor; Intergration with Mixtral (#5682)
* cherry pick from refractor-moe branch
* tests passed
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support ep + zero
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add mixtral auto policy & move pipeline forward code to modeling folder
* [moe refactor] modify kernel test without Route Class
* [moe refactor] add moe tensor test path environment variable to github workflow
* fix typos
* fix moe test bug due to the code rebase
* [moe refactor] fix moe zero test, and little bug in low level zero
* fix typo
* add moe tensor path to github workflow
* remove some useless code
* fix typo & unify global variable XX_AXIS logic without using -1
* fix typo & prettifier the code
* remove print code & support zero 2 test
* remove useless code
* reanme function
* fix typo
* fix typo
* Further improve the test code
* remove print code
* [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test
* [moe refactor] skip some unit test which will be refactored later
* [moe refactor] fix unit import error
* [moe refactor] fix circular import issues
* [moe refactor] remove debug code
* [moe refactor] update github workflow
* [moe/zero] refactor low level optimizer (#5767)
* [zero] refactor low level optimizer
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] MoE refactor with newest version of ZeRO (#5801)
* [zero] remove redundant members in BucketStore (#5802)
* [zero] align api with previous version
* [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819)
* [moe refactor] update unit test with the refactored ZeRO and remove useless test
* move moe checkpoint to checkpoint folder and exchange global axis to class member
* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug
* fix zero unit test
* Add an assertion to prevent users from using it incorrectly
* [hotfix]Solve the compatibility issue of zero refactor (#5823)
* [moe refactor] update unit test with the refactored ZeRO and remove useless test
* move moe checkpoint to checkpoint folder and exchange global axis to class member
* update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug
* fix zero unit test
* Add an assertion to prevent users from using it incorrectly
* Modify function parameter names to resolve compatibility issues
* [zero] fix missing hook removal (#5824)
* [MoE] Resolve .github conflict (#5829)
* [Fix/Example] Fix Llama Inference Loading Data Type (#5763)
* [fix/example] fix llama inference loading dtype
* revise loading dtype of benchmark llama3
* [release] update version (#5752)
* [release] update version
* [devops] update compatibility test
* [devops] update compatibility test
* [devops] update compatibility test
* [devops] update compatibility test
* [test] fix ddp plugin test
* [test] fix gptj and rpc test
* [devops] fix cuda ext compatibility
* [inference] fix flash decoding test
* [inference] fix flash decoding test
* fix (#5765)
* [test] Fix/fix testcase (#5770)
* [fix] branch for fix testcase;
* [fix] fix test_analyzer & test_auto_parallel;
* [fix] remove local change about moe;
* [fix] rm local change moe;
* [Hotfix] Add missing init file in inference.executor (#5774)
* [CI/tests] simplify some test case to reduce testing time (#5755)
* [ci/tests] simplify some test case to reduce testing time
* [ci/tests] continue to remove test case to reduce ci time cost
* restore some test config
* [ci/tests] continue to reduce ci time cost
* [misc] update dockerfile (#5776)
* [misc] update dockerfile
* [misc] update dockerfile
* [devops] fix docker ci (#5780)
* [Inference]Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to the opinions of review
* add Citation
* change _block_tables tolist
* [hotfix] fix llama flash attention forward (#5777)
* [misc] Accelerate CI for zero and dist optim (#5758)
* remove fp16 from lamb
* remove d2h copy in checking states
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Test/CI] remove test cases to reduce CI duration (#5753)
* [test] smaller gpt2 test case
* [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py
* [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py
* [test] reduce test cases tests/test_zero/test_gemini/test_optim.py
* Revert "[test] smaller gpt2 test case"
Some tests might depend on the size of model (num of chunks)
This reverts commit df705a5210b8901645992adf276e320e48766ebf.
* [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py
* [CI] smaller test model for two mwo the two modifid cases
* [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there
* [hotfix] fix testcase in test_fx/test_tracer (#5779)
* [fix] branch for fix testcase;
* [fix] fix test_analyzer & test_auto_parallel;
* [fix] remove local change about moe;
* [fix] rm local change moe;
* [fix] fix test_deepfm_model & test_dlrf_model;
* [fix] fix test_hf_albert & test_hf_gpt;
* [gemini] optimize reduce scatter d2h copy (#5760)
* [gemini] optimize reduce scatter d2h copy
* [fix] fix missing reduce variable
* [refactor] remove legacy async reduce scatter code
* [gemini] missing sync
* Revert "[refactor] remove legacy async reduce scatter code"
This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979.
* [gemini] further optimize with async all reduce
* [fix] pass flag from manager to chunk
* Allow building cuda extension without a device. (#5535)
Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are.
* [misc] fix dist logger (#5782)
* [install]fix setup (#5786)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update requirements (#5787)
* [shardformer] fix import (#5788)
* upgrade colossal-chat support tp_group>1, add sp for sft
* upgrade ppo dpo rm script
* run pre-commit
* moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy
* fix training script
* fix ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix transformers version
* remove duplicated test
* fix datasets version
* remove models that require huggingface auth from ci
* remove local data path
* update ci
* remove baichuan from template test due to transformer version conflict
* merge
* Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Fix tests and naming
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Clean up
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* replace the customized dataloader setup with the build-in one
* replace the customized dataloader setup with the build-in one
* Remove flash attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* fix readme
* Fix test import
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* update sft trainning script
* [Inference]refactor baichuan (#5791)
* refactor baichuan
* remove unused code and add TODO for lazyinit
* [test] fix chatglm test kit (#5793)
* [shardformer] fix modeling of bloom and falcon (#5796)
* [test] fix qwen2 pytest distLarge (#5797)
* [Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Fix flash-attn import
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Add generalized model test
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Remove exposed path to model
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Add default value for use_flash_attn
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* Rename model test
Signed-off-by: char-1ee <xingjianli59@gmail.com>
---------
Signed-off-by: char-1ee <xingjianli59@gmail.com>
* [Gemini] Use async stream to prefetch and h2d data moving (#5781)
* use async stream to prefetch and h2d data moving
* Remove redundant code
* [gemini] quick fix on possible async operation (#5803)
* [gemini] quick fix on possible async operation
* [gemini] quick fix on possible async operation
* [shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)
* [shardformer] fix modeling of gpt2 and gptj
* [shardformer] fix whisper modeling
* [misc] update requirements
---------
Co-authored-by: ver217 <lhx0217@gmail.com>
* [shardformer]upgrade transformers for mistral (#5808)
* upgrade transformers for mistral
* fix
* fix
* [shardformer]upgrade transformers for llama (#5809)
* update transformers
fix
* fix
* fix
* [inference] upgrade transformers (#5810)
* update transformers
fix
* fix
* fix
* fix
* fix
* [gemini] update transformers for gemini (#5814)
---------
Co-authored-by: ver217 <lhx0217@gmail.com>
* Support 4d parallel + flash attention (#5789)
* support tp + sp + pp
* remove comments
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
---------
Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
* [zero] fix hook bug
* [zero] add low level optimizer back (#5839)
* [zero] fix param & refactor
* [zero] add back original low level opt
* [zero] remove moe related
* [zero] pass zero tests
* [zero] refactor
* [chore] add del func back
* [zero] comments and naming (#5840)
* [zero] modify api (#5843)
* [zero] modify api
* [test] remove _grad_store access in tests
* [test] fix (#5857)
* [CI] skip openmoe CI check
* [CI] fix pre-commit
* [zero] remove redundant member init (#5862)
* [misc] remove useless code, modify the pg mesh implementation
* [misc] remove useless code, modify the pg mesh implementation
* [misc] use tempfile
* resolve conflict with main branch
* [misc] use tempfile in test_moe_checkpoint.py
* [misc] remove useless code, add assertion about sequence parallel, move logger into function
* [misc] remove useless code
---------
Signed-off-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: botbw <wang1570@e.ntu.edu.sg>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: char-1ee <xingjianli59@gmail.com>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
5 months ago
|
|
|
indices_at_axis = indices_at_axis
|
|
|
|
if indices_at_axis is None:
|
|
|
|
if isinstance(axis, (list, tuple)):
|
|
|
|
indices_at_axis = list(list(range(self._shape[ax])) for ax in axis)
|
|
|
|
else:
|
|
|
|
indices_at_axis = list(range(self._shape[axis]))
|
|
|
|
|
|
|
|
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
|
|
|
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
|
|
|
if ranks_in_group not in self._ranks_to_group:
|
|
|
|
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
|
|
|
|
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
|
|
|
|
return self._ranks_to_group[ranks_in_group]
|