ColossalAI/colossalai/tensor/d_tensor/sharding_spec.py


from typing import Dict, List
from ..utils import merge_same_dim_mesh_list
from .misc import ShardingOutOfIndexError
__all__ = ["DimSpec", "ShardingException", "ShardingSpec"]
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = "nan"
class DimSpec:
"""
Sharding spec for a single dimension of a sharded tensor. It describes which axes of the
logical device mesh this dimension is sharded along and provides a method to compute the difference between two dim specs.
This class is used internally in ShardingSpec.
Argument:
shard_list(List[int]): if shard_list is empty, the dim spec is of 'R' (replica) type.
Otherwise, each element in shard_list is a logical device mesh axis along which the data in this dimension is sharded.
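Example (repr convention, as implemented in __repr__ below):
DimSpec([])      # replicated, printed as 'R'
DimSpec([0])     # sharded along device mesh axis 0, printed as 'S0'
DimSpec([0, 1])  # sharded along device mesh axes 0 and 1, printed as 'S01'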
"""
_DIFFERENCE_DICT = None
def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0
self.shard_list = shard_list
def __eq__(self, other):
return str(self) == str(other)
def __repr__(self):
if self.is_replica:
return "R"
target = "S"
for dim in self.shard_list:
target += str(dim)
return target
@property
def difference_dict(self):
"""
Returns the difference dict, and lazily initializes it when needed
Return:
difference_dict(Dict[Tuple[str, str], Union[int, str]]):
mapping from a (source_spec, target_spec) pair of dim spec strings to the conversion cost; NAN marks unsupported conversions.
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
return self._DIFFERENCE_DICT
def dim_diff(self, other):
"""
The difference between two DimSpec.
Argument:
other(DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two DimSpec.
Example:
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.dim_diff(other_dim_spec))
Output:
5
"""
difference = self.difference_dict[(str(self), str(other))]
return difference
@classmethod
def _build_difference_2d_dict(cls):
"""
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
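Example (a few of the resulting entries, derived from the cost constants defined above,
i.e. ALLGATHER_COST=20, SHARD_COST=5, STEP_PENALTY=6):
('R', 'R')    -> 0    # already identical
('S0', 'S01') -> 5    # one extra shard step
('S01', 'S0') -> 20   # one all-gather
('S0', 'S1')  -> 31   # all-gather + step penalty + shard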
"""
source_spec_list = ["R", "S0", "S1", "S01"]
target_spec_list = ["R", "S0", "S1", "S01"]
difference_dict = {}
for source_spec in source_spec_list:
for target_spec in target_spec_list:
source_shard_list = cls._convert_str_to_shard_list(source_spec)
target_shard_list = cls._convert_str_to_shard_list(target_spec)
# source same as target
if source_shard_list == target_shard_list:
difference = 0
# all_gather(source) -> target
elif (
len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list
):
difference = ALLGATHER_COST
# shard(source) -> target
elif (
len(source_shard_list) == len(target_shard_list) - 1
and source_shard_list == target_shard_list[:-1]
and target_shard_list[-1] not in source_shard_list
):
difference = SHARD_COST
# S1 -> S0 or S0 -> S1
elif len(source_shard_list) == len(target_shard_list):
# source -> R -> target
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
# R -> S01
elif len(source_shard_list) == len(target_shard_list) - 2:
difference = SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> R
elif len(source_shard_list) == len(target_shard_list) + 2:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
# S1 -> S01
elif len(source_shard_list) == len(target_shard_list) - 1:
difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
# S01 -> S1
elif len(source_shard_list) == len(target_shard_list) + 1:
difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
else:
difference = NAN
difference_dict[(source_spec, target_spec)] = difference
return difference_dict
@staticmethod
def _convert_str_to_shard_list(str_spec):
"""
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
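Example:
_convert_str_to_shard_list('R')    # -> []
_convert_str_to_shard_list('S01')  # -> [0, 1]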
"""
if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpec:
"""
Sharding spec describes how to shard a tensor with dim_size dimensions. For example, for a 3D tensor,
the sharding sequence [R, S0, S1] means the 1st dim is not sharded, the 2nd dim is sharded along the 1st
device mesh axis (process group), and the 3rd dim is sharded along the 2nd device mesh axis. This is useful for, e.g., 2D tensor parallelism.
Argument:
dim_size(int): the number of dimensions of the tensor to be sharded.
dim_partition_dict(Dict[int, List[int]], optional): The key is a tensor dimension to be sharded,
and its value lists the logical device mesh axes along which that dimension is sharded.
sharding_sequence(List[DimSpec], optional): A flattened view of the sharding spec, e.g. [R, R, S0, S1].
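Example (illustrative; assumes a 2D logical device mesh and that merge_same_dim_mesh_list
leaves this partition dict unchanged):
# shard dim 1 along mesh axis 0 and dim 2 along mesh axis 1
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={1: [0], 2: [1]})
print(sharding_spec)
Output:
ShardingSpec:
    shard_sequence: R,S0,S1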
"""
def __init__(
self, dim_size: int, dim_partition_dict: Dict[int, List[int]] = None, sharding_sequence: List[DimSpec] = None
):
self.dims = dim_size
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None:
assert (
self.dim_partition_dict is not None
), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object."
self.dim_partition_dict = merge_same_dim_mesh_list(
dim_size=self.dims, dim_partition_dict=self.dim_partition_dict
)
self.sharding_sequence = self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None:
assert (
self.sharding_sequence is not None
), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object."
self.dim_partition_dict = self.convert_shard_sequence_to_dict()
self._sanity_check()
def _sanity_check(self):
if len(self.sharding_sequence) > self.dims:
raise ShardingOutOfIndexError(
f"sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}."
)
if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims:
raise ShardingOutOfIndexError(
f"the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}."
)
def __repr__(self):
res_list = ["ShardingSpec:"]
res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
return " ".join(res_list)
def convert_dict_to_shard_sequence(self):
"""
Convert dim_partition_dict into a list of DimSpec (the sharding sequence) and return it.
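Example (with self.dims == 3):
dim_partition_dict {0: [0, 1]} is converted to the sequence [S01, R, R].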
"""
sharding_sequence = [DimSpec([])] * self.dims
for dim, shard_list in self.dim_partition_dict.items():
sharding_sequence[dim] = DimSpec(shard_list)
return sharding_sequence
def convert_shard_sequence_to_dict(self):
"""
Convert sharding_sequence into dim_partition_dict.
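Example:
the sequence [S01, R, S1] is converted to {0: [0, 1], 2: [1]}.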
"""
new_dim_partition_dict = {}
for index, dim_spec in enumerate(self.sharding_sequence):
if not dim_spec.is_replica:
if index not in new_dim_partition_dict:
new_dim_partition_dict[index] = []
new_dim_partition_dict[index].extend(dim_spec.shard_list)
return new_dim_partition_dict
def spec_diff(self, other):
"""
This function is a naive version of difference computation. It simply accumulates the per-dimension
difference between the pair of sharding sequences.
Example:
# assuming a 3D tensor and a 2D logical device mesh
dim_partition_dict = {0: [0, 1]}
# shard_sequence: S01,R,R
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
dim_partition_dict_to_compare = {0: [0], 1: [1]}
# shard_sequence: S0,S1,R
sharding_spec_to_compare = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_to_compare)
print(sharding_spec.spec_diff(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to compare with.
Return:
difference(int): Difference between two ShardingSpec.
"""
assert len(self.sharding_sequence) == len(
other.sharding_sequence
), f"Cannot compare difference for two sharding specs with different length."
difference = 0
for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
difference += orig_dim_spec.dim_diff(other_dim_spec)
return difference
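
# Minimal usage sketch: build two sharding specs for a 3D tensor on an assumed 2D logical
# device mesh and compare them. Because this file uses relative imports, run it as a
# module, e.g. `python -m colossalai.tensor.d_tensor.sharding_spec`.
if __name__ == "__main__":
    spec_a = ShardingSpec(dim_size=3, dim_partition_dict={0: [0, 1]})  # shard_sequence: S01,R,R
    spec_b = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})  # shard_sequence: S0,S1,R
    print(spec_a)
    print(spec_b)
    # one all-gather (20) on dim 0 plus one shard (5) on dim 1 -> 25
    print(spec_a.spec_diff(spec_b))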