mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix bugs caused by negative dim key (#1808)
* [autoparallel] fix bugs caused by negative dim key
* fix import error
* fix matmul test issue
* fix unit test issue

Branch: pull/1857/head
Parent: 4268ae017b
Commit: 49216d7ab1
@@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler):
         if -1 in dim_partition_dict:
             shard = dim_partition_dict.pop(-1)
             dim_partition_dict[0] = shard
+        if 1 in dim_partition_dict:
+            shard = dim_partition_dict.pop(1)
+            dim_partition_dict[0] = shard
 
         # re-init the sharding spec
         input_sharding_spec.__init__(input_sharding_spec.device_mesh,
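For reference, a minimal standalone sketch of the normalization this hunk performs: MatMulHandler now treats a shard recorded under key 1, like one recorded under key -1, as belonging to dim 0 of the operand before re-initializing its sharding spec. The helper name below is illustrative and not part of ColossalAI.

    def normalize_operand_dim_keys(dim_partition_dict):
        # Move a shard stored under key -1 or 1 onto key 0, mirroring the
        # two `if` branches in the MatMulHandler hunk above.
        for key in (-1, 1):
            if key in dim_partition_dict:
                dim_partition_dict[0] = dim_partition_dict.pop(key)
        return dim_partition_dict

    print(normalize_operand_dim_keys({-1: [0]}))    # {0: [0]}
    print(normalize_operand_dim_keys({1: [0, 1]}))  # {0: [0, 1]}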
@@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )
+from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
     def split_input_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
         dim_partition_dict_mapping = {
@@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
         dim_partition_dict_mapping = {
@@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def non_split(self):
         name = f'RR = RR x R'
         dim_partition_dict_mapping = {
@@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_batch(self, mesh_dim_0):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
         dim_partition_dict_mapping = {
@@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
         dim_partition_dict_mapping = {
@@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
+    @ignore_sharding_exception
     def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
         dim_partition_dict_mapping = {
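The only functional change in the hunks above is wrapping each strategy-generating method of BatchNormStrategyGenerator in @ignore_sharding_exception. The decorator's implementation is not part of this diff; as a rough mental model only (not ColossalAI's actual code), it can be pictured as swallowing sharding-related failures so that one invalid strategy does not abort the whole enumeration:

    import functools

    def ignore_sharding_exception_sketch(func):
        # Hypothetical stand-in for colossalai's ignore_sharding_exception:
        # an invalid sharding combination simply yields no strategy.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                return None
        return wrapper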
@@ -3,9 +3,12 @@ import operator
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                          enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
+)
 
 from .strategy_generator import StrategyGenerator
 
@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
     def _generate_strategy_with_dim_partition(self, dim_partition):
         dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
 
@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.utils import convert_dim_partition_dict
 
 
 class StrategyGenerator(ABC):
@@ -74,11 +75,15 @@ class StrategyGenerator(ABC):
         op_data = self.op_data[op_data_name]
         if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
             sharding_spec = []
-            for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
+            for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
+                dim_size = len(logical_shape)
+                dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
                 sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                             entire_shape=output.shape,
+                                             entire_shape=logical_shape,
                                              dim_partition_dict=dim_partition_dict_element)
         else:
+            dim_size = len(op_data.logical_shape)
+            dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
             sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                          entire_shape=op_data.logical_shape,
                                          dim_partition_dict=dim_partition_dict)
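The conversion matters here because StrategyGenerator now builds each ShardingSpec from the operand's logical shape, so any negative dim key must first be resolved against that shape. A small illustrative sketch with made-up values (it mirrors what convert_dim_partition_dict does rather than calling it):

    logical_shape = (16, 8, 32, 64)    # made-up 4-D logical shape
    dim_partition_dict = {-1: [1]}     # shard the last dim along mesh axis 1
    dim_size = len(logical_shape)
    converted = {d if d >= 0 else dim_size + d: axes for d, axes in dim_partition_dict.items()}
    assert converted == {3: [1]}       # ShardingSpec only ever sees non-negative dim keys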
@@ -1,19 +1,17 @@
+from . import distspec
+from .colo_parameter import ColoParameter
+from .colo_tensor import ColoTensor
+from .comm_spec import CollectiveCommPattern, CommSpec
+from .compute_spec import ComputePattern, ComputeSpec
+from .dist_spec_mgr import DistSpecManager
+from .distspec import ReplicaSpec, ShardSpec
+from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec
-from .distspec import ShardSpec
-from .distspec import ReplicaSpec
-
-from .compute_spec import ComputeSpec, ComputePattern
-from .colo_tensor import ColoTensor
-from .colo_parameter import ColoParameter
-from .utils import convert_parameter, named_params_with_colotensor
-from .dist_spec_mgr import DistSpecManager
-from .param_op_hook import ParamOpHook, ParamOpHookManager
-from .comm_spec import CollectiveCommPattern, CommSpec
-from . import distspec
+from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
     'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
-    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern'
+    'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
 ]
@@ -1,11 +1,11 @@
-import torch
-
 from typing import Optional
 
+import torch
+
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
-from colossalai.tensor import ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
 def filter_args(func, *args):
@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set
 
 import torch
 
-from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 from .const import TensorType
 from .op_wrapper import _COLOSSAL_OPS
@@ -1,12 +1,14 @@
-from colossalai.tensor.distspec import _DistSpec
-# from colossalai.nn.layer.utils import divide
-from numpy import prod
 from contextlib import contextmanager
+
 import torch
 import torch.distributed as dist
+# from colossalai.nn.layer.utils import divide
+from numpy import prod
 from packaging import version
+
 from colossalai.logging import get_dist_logger
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
 
 
 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -1,9 +1,11 @@
-import torch
-from contextlib import contextmanager
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Any
+from contextlib import contextmanager
+from typing import Any, List, Tuple
+
+import torch
 
 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor import ColoTensorSpec
+from colossalai.tensor.tensor_spec import ColoTensorSpec
 
 
 class ParamOpHook(ABC):
@@ -6,6 +6,8 @@ import torch
 
 from colossalai.device.device_mesh import DeviceMesh
 
+from .utils import merge_same_dim_mesh_list
+
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
 
 ALLGATHER_COST = 20
@@ -181,8 +183,12 @@ class ShardingSpec:
         self.dim_partition_dict = dim_partition_dict
         self.sharding_sequence = sharding_sequence
         if self.sharding_sequence is None:
+            assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+            self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
+                                                               dim_partition_dict=self.dim_partition_dict)
             self.convert_dict_to_shard_sequence()
         elif self.dim_partition_dict is None:
+            assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
             self.convert_shard_sequence_to_dict()
         self._sanity_check()
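A short usage sketch of the new merge step: when a dim_partition_dict addresses the same physical axis through both a positive and a negative key, ShardingSpec now collapses the two entries before deriving the sharding sequence. The import path assumes the helper lives in colossalai.tensor.utils, as the `from .utils import ...` line above suggests.

    from colossalai.tensor.utils import merge_same_dim_mesh_list

    # For a 2-D tensor, keys 1 and -1 refer to the same dim; the merged dict
    # is what ShardingSpec.__init__ now converts into a sharding sequence.
    merged = merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: [0], -1: [1]})
    assert merged == {1: [0, 1]}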
@@ -1,14 +1,16 @@
-from typing import Optional
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from .compute_spec import ComputeSpec
-from colossalai.tensor import ProcessGroup
 from dataclasses import dataclass
+from typing import Optional
+
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+from colossalai.tensor.process_group import ProcessGroup
+
+from .compute_spec import ComputeSpec
 
 
 @dataclass
 class ColoTensorSpec:
     """ ColoTensorSpec
 
     A data class for specifications of the `ColoTensor`.
     It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
     The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.
@@ -1,7 +1,8 @@
-import torch
+from typing import Dict, Iterator, List, Tuple, Union
 
-from typing import Iterator, Tuple, Union
+import torch
 import torch.nn as nn
 
 from colossalai.tensor.colo_tensor import ColoTensor
+
 
@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
 
     We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
     Therefore, all gather operation just remove the last element in shard list,
     e.g.:
         all-gather(S01) -> S0
 
     Argument:
@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
     and simulate the influence of the DimSpec.
 
     We BANNED all representations which shard_list in decreasing order,
     such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
     Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
         and the second element decribes which logical axis will be sharded in that dimension.
     e.g.:
         all-to-all(S0, S1) -> [S01, R]
         all-to-all(S0, R) -> [R, S0]
     Otherwise, we extend the front shard_list to behind.
     e.g.:
         all-to-all(R, S1) -> [S1, R]
 
     Argument:
         target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
         and the second element decribes which logical axis will be sharded in that dimension.
@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
     and simulate the influence of the DimSpec.
 
     We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
     In addition, We BANNED all representations which shard_list in decreasing order,
     such as S10, so shard(S0) -> S10 is NOT allowed.
     Therefore, for the R dimension, we could just append any legal sharding dim on it.
     e.g.:
@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
 
     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to convert the negative dim value to positive.
+    '''
+    dims_to_convert = []
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dims_to_convert.append(dim)
+    for dim in dims_to_convert:
+        dim_partition_dict.pop(dim)
+        dim_partition_dict[dim_size + dim] = mesh_list
+    return dim_partition_dict
+
+
+def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
+    '''
+    This method is used to merge the different key value which points to same physical position.
+
+    For example:
+    dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
+    In this method, above dim_partition_dict will be converted to {1: [0, 1]}
+    '''
+    converted_dim_partition_dict = {}
+    for dim, mesh_list in dim_partition_dict.items():
+        if dim < 0:
+            dim = dim_size + dim
+        if dim not in converted_dim_partition_dict:
+            converted_dim_partition_dict[dim] = mesh_list
+        else:
+            converted_dim_partition_dict[dim].extend(mesh_list)
+
+    return converted_dim_partition_dict
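A brief usage sketch of the two helpers added above, with illustrative values; the import assumes this file is colossalai/tensor/utils.py, and the package __init__ change earlier in the diff also re-exports both names from colossalai.tensor.

    from colossalai.tensor.utils import convert_dim_partition_dict, merge_same_dim_mesh_list

    # Negative dim keys are rewritten to their positive counterparts.
    assert convert_dim_partition_dict(dim_size=4, dim_partition_dict={-1: [0]}) == {3: [0]}

    # Keys aliasing the same physical dim (1 and -1 of a 2-D tensor) are merged,
    # concatenating their device-mesh axis lists.
    assert merge_same_dim_mesh_list(dim_size=2, dim_partition_dict={1: [0], -1: [1]}) == {1: [0, 1]}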