mirror of https://github.com/hpcaitech/ColossalAI

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy
* polish code

branch: pull/1743/head
parent: cbe9a4cb45
commit: eee84908d4
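In essence, this commit renames the exception_handler decorator to ignore_sharding_exception and lets strategy generators build candidate strategies freely: any strategy whose construction raises a sharding exception comes back as None and is filtered out before costs are computed. The following minimal sketch illustrates that pattern; ShardingSpecException, DummyGenerator and split_input here are simplified stand-ins for illustration, not the actual ColossalAI API.

from functools import wraps
from typing import List, Optional


class ShardingSpecException(Exception):
    """Stand-in for the exception raised when a sharding spec is illegal."""


def ignore_sharding_exception(func):
    """Return None instead of raising when an illegal sharding strategy is built."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ShardingSpecException as e:
            print(f'skipped illegal strategy: {e}')
            return None

    return wrapper


class DummyGenerator:
    """Mimics the new StrategyGenerator.generate()/collate_strategies() split."""

    @ignore_sharding_exception
    def split_input(self, mesh_dim):
        if mesh_dim > 1:    # pretend larger mesh dims are not divisible
            raise ShardingSpecException(f'dim {mesh_dim} is not shardable')
        return f'S{mesh_dim}R = S{mesh_dim}R x RR'

    def collate_strategies(self) -> List[Optional[str]]:
        return [self.split_input(0), self.split_input(1), self.split_input(2)]

    def generate(self) -> List[str]:
        # illegal strategies come back as None and are simply dropped,
        # mirroring the filtering added to StrategyGenerator.generate()
        return [s for s in self.collate_strategies() if s]


print(DummyGenerator().generate())    # ['S0R = S0R x RR', 'S1R = S1R x RR']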
@@ -1,13 +1,15 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node

from .constants import INFINITY_COST

@@ -87,7 +89,7 @@ def generate_resharding_costs(nodes: List[Node],
return resharding_costs


def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper which executes the function with a specified seed.
"""

@@ -1,9 +1,12 @@
import operator
from functools import reduce

import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)

from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler

__all__ = ['BatchNormHandler']

@@ -110,7 +113,7 @@ class BatchNormHandler(OperatorHandler):

return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation

@exception_handler
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

@@ -185,7 +188,7 @@ class BatchNormHandler(OperatorHandler):

self.strategies_vector.append(sharding_strategies)

@exception_handler
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

@@ -226,7 +229,7 @@ class BatchNormHandler(OperatorHandler):

self.strategies_vector.append(sharding_strategies)

@exception_handler
@ignore_sharding_exception
def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'

@@ -322,7 +325,7 @@ class BatchNormHandler(OperatorHandler):
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)

@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

@@ -363,7 +366,7 @@ class BatchNormHandler(OperatorHandler):

self.strategies_vector.append(sharding_strategies)

@exception_handler
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

@@ -404,7 +407,7 @@ class BatchNormHandler(OperatorHandler):

self.strategies_vector.append(sharding_strategies)

@exception_handler
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

@ -1,14 +1,18 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
ignore_sharding_exception)
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['BcastOpHandler']
|
||||
|
||||
|
@ -136,7 +140,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
return output_sharding_spec_list
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _register_strategy(self, output_sharding_spec):
|
||||
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
|
||||
|
@ -171,7 +175,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
##############################################
|
||||
#used to generate strategies for torch.matmul#
|
||||
##############################################
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
|
||||
# this dim partition dict only describes the batch dimensions, but in this scenario,
|
||||
# matrix dimensions are fully replicated, so it do not need extra process.
|
||||
|
@ -210,7 +214,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
|
||||
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
|
||||
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
|
||||
|
@ -268,7 +272,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
|
||||
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
|
||||
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
|
||||
|
@ -332,7 +336,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
|
||||
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
|
||||
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
|
||||
|
@ -398,7 +402,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
self._split_dim_k(dim_partition_dict, mesh_dim_list)
|
||||
self._split_dim_j(dim_partition_dict, mesh_dim_list)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
|
@ -435,7 +439,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
|
@ -474,7 +478,7 @@ class BcastOpHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
|
||||
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
|
||||
|
||||
__all__ = ['ConvHandler']
|
||||
|
||||
|
@ -105,7 +108,7 @@ class ConvHandler(OperatorHandler):
|
|||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
|
@ -153,7 +156,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch(self, mesh_dim_0):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
|
||||
|
||||
|
@ -199,7 +202,7 @@ class ConvHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
|
@ -245,7 +248,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
|
@ -288,7 +291,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
|
||||
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
|
||||
|
||||
|
@ -331,7 +334,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_weight_out_channel(self, mesh_dim_0):
|
||||
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
||||
|
||||
|
@ -374,7 +377,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def non_split(self):
|
||||
name = f'RR = RR x RR'
|
||||
|
||||
|
@ -415,7 +418,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
|
@ -463,7 +466,7 @@ class ConvHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
import operator
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
|
||||
from functools import reduce
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
|
||||
from enum import Enum
|
||||
from .strategy_generator import StrategyGenerator, IntermediateStrategy
|
||||
from typing import List
|
||||
from .operator_handler import OperatorHandler
|
||||
from .strategy_generator import IntermediateStrategy, StrategyGenerator
|
||||
|
||||
__all__ = ['DotHandler']
|
||||
|
||||
|
@ -415,7 +418,7 @@ class DotHandler(OperatorHandler):
|
|||
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
|
||||
return compute_cost
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle case SS = SR x RS
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
@ -456,7 +459,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle the case SR = SS x SR
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
@ -496,7 +499,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
|
@ -534,7 +537,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def recompute_split_both_contract(self, mesh_dim):
|
||||
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
|
||||
|
||||
|
@ -569,7 +572,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_space_only(self, mesh_dim):
|
||||
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
|
||||
|
||||
|
@ -605,7 +608,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
|
@ -641,7 +644,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
|
@ -678,7 +681,7 @@ class DotHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['EmbeddingHandler']
|
||||
|
||||
|
@ -76,7 +79,7 @@ class EmbeddingHandler(OperatorHandler):
|
|||
|
||||
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
|
@ -117,7 +120,7 @@ class EmbeddingHandler(OperatorHandler):
|
|||
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
|
||||
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
generate_sharding_size, ignore_sharding_exception)
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
|
||||
|
||||
__all__ = ['LayerNormHandler']
|
||||
|
||||
|
@ -149,21 +153,21 @@ class LayerNormHandler(OperatorHandler):
|
|||
|
||||
self.strategies_vector.append(sharding_strategies)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
|
||||
batch_dimension_length = self.input_data.dim() - self.weight.dim()
|
||||
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
|
||||
for dim_partition in dim_partition_list:
|
||||
self._generate_strategy_with_dim_partition(dim_partition)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
|
||||
batch_dimension_length = self.input_data.dim() - self.weight.dim()
|
||||
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
|
||||
for dim_partition in dim_partition_list:
|
||||
self._generate_strategy_with_dim_partition(dim_partition)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def non_split(self):
|
||||
name = f'RR = RR x R'
|
||||
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import colorsys
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from copy import deepcopy
|
||||
import math
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from ..constants import INFINITY_COST
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
|
||||
class ReshapeHandler(OperatorHandler):
|
||||
|
@ -24,7 +27,7 @@ class ReshapeHandler(OperatorHandler):
|
|||
def _generate_compute_cost(self, *args, **kwargs):
|
||||
return super()._generate_compute_cost(*args, **kwargs)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def register_strategy(self):
|
||||
# TODO: add strategies with more output sharding specs other than only fully replicated.
|
||||
input_node = self.strategies_vector.predecessor_nodes[0]
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
|
||||
INFINITY_COST
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
import math
|
||||
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
|
||||
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['UnaryElementwiseHandler']
|
||||
|
||||
|
@ -40,7 +44,7 @@ class UnaryElementwiseHandler(OperatorHandler):
|
|||
def _generate_compute_cost(self, *args, **kwargs):
|
||||
return super()._generate_compute_cost(*args, **kwargs)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def register_strategy(self):
|
||||
# TODO: integrate element-wise func and module together
|
||||
# create sharding strategy for element-wise function
|
||||
|
|
|
@@ -6,12 +6,10 @@ from typing import Dict, List

import torch

from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
exception_handler,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

@@ -146,7 +144,7 @@ class WhereHandler(OperatorHandler):

return output_sharding_spec_list

@exception_handler
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)

|
@ -5,7 +5,8 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
|
||||
from colossalai.tensor.sharding_spec import ShardingException
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
|
||||
|
||||
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
|
@ -15,6 +16,100 @@ from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyG
|
|||
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
|
||||
|
||||
|
||||
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
|
||||
weight_name: str) -> ShardingStrategy:
|
||||
"""
|
||||
This function is a helper function used by both module node handler and function node handler. This function will
|
||||
convert the sharding spec for the transposed weight to the correct partititon spec.
|
||||
|
||||
Args:
|
||||
strategy (ShardingStrategy): the strategy generated by the strategy generator.
|
||||
weight_name (str): the name of the OperationData object for the weight.
|
||||
"""
|
||||
# switch the dimensions of the transposed weight
|
||||
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
|
||||
op_data = strategy.get_op_data_by_name(weight_name)
|
||||
assert op_data.logical_shape != op_data.data.shape, \
|
||||
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
|
||||
switch_partition_dim(sharding_spec, 0, -1)
|
||||
return strategy
|
||||
|
||||
|
||||
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
|
||||
output_name: str) -> List[ShardingStrategy]:
|
||||
"""
|
||||
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
|
||||
should have the same sharding spec.
|
||||
|
||||
Args:
|
||||
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
|
||||
input_name (str): the name of the OperationData object for the input.
|
||||
output_name (str): the name of the OperationData object for the output.
|
||||
|
||||
|
||||
"""
|
||||
# the result will be a list of strategies
|
||||
sharding_strategies = []
|
||||
|
||||
# get operation data
|
||||
input_op_data = strategy.get_op_data_by_name(input_name)
|
||||
output_op_data = strategy.get_op_data_by_name(output_name)
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
|
||||
|
||||
# get logger for debug message
|
||||
logger = get_dist_logger()
|
||||
|
||||
# for the input of the linear operation, it can be multi-dimensional. The sharding spec generated is only
|
||||
# 2D, where the first dimension is non-matrix dimension and the last dimension is the matrix dimension.
|
||||
# the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape.
|
||||
# Thus, we enumerate to get all possible cases.
|
||||
if 0 in input_sharding_spec.dim_partition_dict:
|
||||
# if 0 is in the dim_partition_dict, it means that the
|
||||
# the generated sharding strategy does shard the non-matrix dimension,
|
||||
# in this case, we need to do enumeration
|
||||
num_input_dims = input_op_data.data.dim()
|
||||
for i in range(num_input_dims - 1):
|
||||
strategy_copy = strategy.clone()
|
||||
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
try:
|
||||
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
sharding_strategies.append(strategy_copy)
|
||||
except ShardingNotDivisibleError as e:
|
||||
logger.debug(
|
||||
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
|
||||
)
|
||||
else:
|
||||
# the generated sharding strategy does not shard the non-matrix dimension,
|
||||
# in this case, we don't need to do enumeration
|
||||
# but instead, we still need to convert the logical shape to physical shape
|
||||
strategy_copy = strategy.clone()
|
||||
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
|
||||
# after updating, the logical shape will be replaced by the physical shape
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
print(input_op_data.data.shape)
|
||||
print(output_op_data.data.shape)
|
||||
sharding_strategies.append(strategy_copy)
|
||||
return sharding_strategies
|
||||
|
||||
|
||||
@operator_registry.register(torch.nn.Linear)
|
||||
class LinearModuleHandler(ModuleHandler):
|
||||
"""
|
||||
|
@ -58,44 +153,20 @@ class LinearModuleHandler(ModuleHandler):
|
|||
|
||||
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
|
||||
"""
|
||||
Convert the sharding spec from the logical shape to the physical shape.
|
||||
Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed:
|
||||
1. the sharding spec is updated for the transposed weight
|
||||
2. the input and output sharding specs are updated to physical shape.
|
||||
"""
|
||||
# switch the dimensions of the transposed weight
|
||||
for op_data, sharding_spec in strategy.input_sharding_specs.items():
|
||||
if op_data.name == "weight":
|
||||
assert op_data.logical_shape != op_data.data.shape
|
||||
switch_partition_dim(sharding_spec, 0, -1)
|
||||
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
|
||||
|
||||
# create multiple sharding strategies for the inputs
|
||||
# as input can be multi-dimensinal and the partition dim is only 2D,
|
||||
# we need to map the partition at dim 0 to one of the first few dimensions of the input
|
||||
sharding_strategies = []
|
||||
input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
|
||||
output_op_data = strategy.get_op_data_by_name(str(self.node))
|
||||
num_input_dims = input_op_data.data.dim()
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
|
||||
|
||||
if 0 in input_sharding_spec.dim_partition_dict:
|
||||
for i in range(num_input_dims - 1):
|
||||
new_strategy = strategy.clone()
|
||||
input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
|
||||
try:
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
sharding_strategies.append(new_strategy)
|
||||
except ShardingException:
|
||||
pass
|
||||
else:
|
||||
sharding_strategies.append(strategy)
|
||||
|
||||
return sharding_strategies
|
||||
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
|
||||
input_name=str(self.node.args[0]),
|
||||
output_name=str(self.node))
|
||||
return strategies
|
||||
|
||||
|
||||
@operator_registry.register(F.linear)
|
||||
|
@ -113,9 +184,12 @@ class LinearFunctionHandler(NodeHandler):
|
|||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# use transposed shape for strategies
|
||||
# the strategies will be transformed back to its original shape in self.post_process
|
||||
input_meta_data = self.node.args[0]._meta_data
|
||||
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
||||
type=OperationDataType.ARG,
|
||||
data=self.node.args[0]._meta_data)
|
||||
data=self.node.args[0]._meta_data,
|
||||
logical_shape=input_logical_shape)
|
||||
|
||||
# check if the other operand is a parameter
|
||||
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
|
||||
|
@ -144,44 +218,17 @@ class LinearFunctionHandler(NodeHandler):
|
|||
return mapping
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy):
|
||||
"""
|
||||
Convert the sharding spec of the weight parameter back to its original shape.
|
||||
"""
|
||||
for op_data, sharding_spec in strategy.input_sharding_specs.items():
|
||||
if op_data.name == str(self.node.args[1]):
|
||||
assert op_data.logical_shape != op_data.data.shape
|
||||
switch_partition_dim(sharding_spec, 0, -1)
|
||||
# switch the dimensions of the transposed weight
|
||||
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
|
||||
weight_name=str(self.node.args[1]))
|
||||
|
||||
# create multiple sharding strategies for the inputs
|
||||
# as input can be multi-dimensinal and the partition dim is only 2D,
|
||||
# we need to map the partition at dim 0 to one of the first few dimensions of the input
|
||||
sharding_strategies = []
|
||||
input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
|
||||
output_op_data = strategy.get_op_data_by_name(str(self.node))
|
||||
num_input_dims = input_op_data.data.dim()
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
|
||||
|
||||
if 0 in input_sharding_spec.dim_partition_dict:
|
||||
for i in range(num_input_dims - 1):
|
||||
new_strategy = strategy.clone()
|
||||
input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
|
||||
output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
|
||||
try:
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
sharding_strategies.append(new_strategy)
|
||||
except ShardingException:
|
||||
pass
|
||||
else:
|
||||
sharding_strategies.append(strategy)
|
||||
|
||||
return strategy
|
||||
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
|
||||
input_name=str(self.node.args[0]),
|
||||
output_name=str(self.node))
|
||||
return strategies
|
||||
|
||||
|
||||
@operator_registry.register(torch.bmm)
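As a side note on the new linear handlers above: the input is viewed through a 2D logical shape (all leading dimensions flattened into the rows, the feature dimension as the columns), so a strategy that shards logical dim 0 has to be re-mapped onto one of the physical leading dimensions, which is what the loop over range(num_input_dims - 1) does. Below is a toy illustration of that enumeration with plain dictionaries; it is a hedged sketch and does not use the real ShardingSpec or update_partition_dim API.

def enumerate_physical_candidates(logical_partition, physical_ndim):
    """logical_partition maps a logical dim (0 = flattened rows, -1 = features) to mesh dims."""
    if 0 not in logical_partition:
        # nothing sharded on the non-matrix dimension: keep the spec as-is
        return [dict(logical_partition)]
    candidates = []
    for i in range(physical_ndim - 1):    # every leading dim except the feature dim
        mapped = {k: v for k, v in logical_partition.items() if k != 0}
        mapped[i] = logical_partition[0]
        candidates.append(mapped)
    return candidates


# logical spec S0R on the [batch * seq, hidden] view of a [batch, seq, hidden] input
print(enumerate_physical_candidates({0: [0]}, physical_ndim=3))
# -> [{0: [0]}, {1: [0]}]  i.e. shard the batch dim, or shard the sequence dim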
|
||||
|
|
|
@@ -1,6 +1,7 @@
import copy
import operator
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

@@ -292,7 +293,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
'''

@@ -325,9 +326,4 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# S01R = S01R x R WITH SYNC_BN
# strategy_list.append(self.split_input_batch_1d(0, 1))

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list

@ -5,7 +5,8 @@ from functools import reduce
|
|||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
|
||||
from colossalai.auto_parallel.tensor_shard.utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
@ -25,8 +26,8 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
|
||||
'''
|
||||
input_op_data = self.op_data['input']
|
||||
assert input_op_data.dim() in (3, 4,
|
||||
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
|
||||
assert input_op_data.data.dim() in (
|
||||
3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
|
||||
|
||||
def update_compute_cost(self, strategy: ShardingStrategy):
|
||||
'''
|
||||
|
@ -99,7 +100,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
||||
|
@ -146,7 +147,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_batch(self, mesh_dim_0):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
|
||||
|
||||
|
@ -183,7 +184,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
||||
|
@ -230,7 +231,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
|
@ -270,7 +271,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
|
||||
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
|
||||
|
||||
|
@ -301,7 +302,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_weight_out_channel(self, mesh_dim_0):
|
||||
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
|
||||
|
||||
|
@ -334,7 +335,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def non_split(self):
|
||||
name = f'RR = RR x RR'
|
||||
|
||||
|
@ -353,7 +354,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping={})
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
|
||||
|
@ -391,7 +392,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
dim_partition_dict_mapping = {
|
||||
|
@ -421,7 +422,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@exception_handler
|
||||
@ignore_sharding_exception
|
||||
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
dim_partition_dict_mapping = {
|
||||
|
@ -453,7 +454,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategies = []
|
||||
# SS = SR x RS
|
||||
strategies.append(self.split_input_batch_weight_out_channel(0, 1))
|
||||
|
@ -491,20 +492,4 @@ class ConvStrategyGenerator(StrategyGenerator):
|
|||
# RS01 = RR x RS01
|
||||
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
|
||||
|
||||
rm_list = [strategy for strategy in strategies if strategy is None]
|
||||
for rm_element in rm_list:
|
||||
strategies.remove(rm_element)
|
||||
illegal_strategy_list = []
|
||||
# update mete info on cost
|
||||
for strategy in strategies:
|
||||
try:
|
||||
self.update_communication_cost(strategy)
|
||||
self.update_compute_cost(strategy)
|
||||
self.update_memory_cost(strategy)
|
||||
except AssertionError as e:
|
||||
illegal_strategy_list.append(strategy)
|
||||
warnings.warn(f'{e}')
|
||||
for strategy in illegal_strategy_list:
|
||||
strategies.remove(strategy)
|
||||
|
||||
return strategies
|
||||
|
|
|
@@ -1,4 +1,5 @@
import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

@@ -61,7 +62,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
Deal with case 1 and 2.
'''

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for strategy in self.predecessor_node.strategies_vector:
dim_partition_dict_mapping = {}

@@ -109,7 +110,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
Deal with case 3.
'''

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
index = self.op_data["index"].data

@@ -133,9 +134,4 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):

strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list

@@ -1,6 +1,7 @@
import copy
import operator
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,

@@ -159,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
'''

@@ -178,11 +179,5 @@ class LayerNormGenerator(StrategyGenerator):

# RR = RR x R
strategy_list.append(self.non_split())
# update mete info on cost

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list

@ -3,6 +3,8 @@ from functools import reduce
|
|||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import \
|
||||
ignore_sharding_exception
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
@ -169,7 +171,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
strategy.compute_cost = compute_cost
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategies = []
|
||||
|
||||
# SS = SR x RS
|
||||
|
@ -201,14 +203,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
# RS01 = RR x RS01
|
||||
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
|
||||
|
||||
# update mete info on cost
|
||||
for strategy in strategies:
|
||||
self.update_communication_cost(strategy)
|
||||
self.update_compute_cost(strategy)
|
||||
self.update_memory_cost(strategy)
|
||||
|
||||
return strategies
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle case SS = SR x RS
|
||||
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||
|
@ -249,6 +246,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle the case SR = SS x SR
|
||||
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||
|
@ -289,6 +287,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||
|
||||
|
@ -324,6 +323,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def recompute_split_both_contract(self, mesh_dim):
|
||||
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
|
||||
|
||||
|
@ -351,6 +351,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_space_only(self, mesh_dim):
|
||||
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
|
||||
|
||||
|
@ -380,6 +381,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
|
||||
# get sharding spec
|
||||
|
@ -410,6 +412,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communcation_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
|
||||
|
||||
|
@ -437,6 +440,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
|
||||
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
|
||||
|
||||
|
@ -542,7 +546,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec=sharding_spec_mapping['bias'],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1])
|
||||
communication_action_mappingp['bias'] = bias_comm_spec
|
||||
communication_action_mapping['bias'] = bias_comm_spec
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@ -662,7 +666,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
|||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
|
||||
def generate(self) -> List[ShardingStrategy]:
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
device_mesh_is_1d = True
|
||||
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
|
||||
|
|
|
@@ -25,8 +25,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'

def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''

@@ -103,7 +103,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):

return dim_partition_list

def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []

dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)

@@ -111,9 +111,4 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list

@@ -1,3 +1,5 @@
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)

from .strategy_generator import OutputStrategyGenerator

@@ -30,7 +32,7 @@ class OutputGenerator(OutputStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = {
"output": {},
}

@@ -47,8 +49,4 @@ class OutputGenerator(OutputStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)

self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return [strategy]

@@ -1,3 +1,5 @@
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)

from .strategy_generator import StrategyGenerator

@@ -35,7 +37,7 @@ class PlaceholderGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = {
"output": {},
}

@@ -48,8 +50,4 @@ class PlaceholderGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)

self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return [strategy]

@@ -1,4 +1,5 @@
import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern

@@ -49,7 +50,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in

@@ -4,13 +4,12 @@ from functools import reduce
from typing import Any, Dict, List, Union

import torch
from torch.fx import Node

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx import Node


class StrategyGenerator(ABC):

@@ -24,6 +23,9 @@ class StrategyGenerator(ABC):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh

# validate the whether operation data is of desired value
self.validate()

@property
def has_bias(self):
"""

@@ -102,9 +104,9 @@ class StrategyGenerator(ABC):
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

def _compute_and_add(data: OperationData, comm_spec: CommSpec):
def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = operand.data.dtype
dtype = op_data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes

@@ -151,11 +153,30 @@ class StrategyGenerator(ABC):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes

@abstractmethod
def generate(self) -> List[ShardingStrategy]:
"""
Generate all possible sharding strategies for this operation.
"""
strategies = self.collate_strategies()

# some strategies may be None as ignore_sharding_exception may return None
# when ShardingSpecException occurs.
# thus, remove those None values
strategies = [strategy for strategy in strategies if strategy]

# update the costs
# update mete info on cost
# these update methods are all in-place, the default method will do nothing
# the cost info will only be added if the child class overrides these methods
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategies

@abstractmethod
def collate_strategies(self) -> List[ShardingStrategy]:
pass

@abstractmethod
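The hunk above is the heart of the refactor: generate becomes a concrete template method on the base class, while each concrete generator now only implements collate_strategies. Strategies that were dropped by ignore_sharding_exception arrive as None and are filtered out before the in-place cost updates run. A minimal, self-contained sketch of this control flow (the Toy* names are illustrative stand-ins, not part of the ColossalAI API):

from abc import ABC, abstractmethod
from typing import List, Optional


class ToyStrategy:
    """Stand-in for ShardingStrategy: just a name plus mutable cost fields."""

    def __init__(self, name: str):
        self.name = name
        self.comm_cost = self.compute_cost = self.memory_cost = 0


class ToyStrategyGenerator(ABC):
    """Mirrors the refactored base class: generate() is shared, collate_strategies() is abstract."""

    def generate(self) -> List[ToyStrategy]:
        strategies = self.collate_strategies()
        # drop None entries, i.e. strategies skipped because their sharding spec was illegal
        strategies = [s for s in strategies if s]
        # cost updates are in-place; the default implementations are no-ops
        for s in strategies:
            self.update_communication_cost(s)
            self.update_compute_cost(s)
            self.update_memory_cost(s)
        return strategies

    @abstractmethod
    def collate_strategies(self) -> List[Optional[ToyStrategy]]:
        ...

    def update_communication_cost(self, strategy: ToyStrategy) -> None:
        pass

    def update_compute_cost(self, strategy: ToyStrategy) -> None:
        pass

    def update_memory_cost(self, strategy: ToyStrategy) -> None:
        pass


class ToyPoolGenerator(ToyStrategyGenerator):

    def collate_strategies(self):
        # the None mimics a strategy rejected by ignore_sharding_exception
        return [ToyStrategy('S0R = S0R x R'), None, ToyStrategy('RR = RR x R')]


print([s.name for s in ToyPoolGenerator().generate()])    # ['S0R = S0R x R', 'RR = RR x R']

Moving the loop into the base class is why the duplicated cost-update blocks disappear from the pooling, output, placeholder, reshape, unary-elementwise and where generators in the hunks above and below.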
@@ -1,4 +1,5 @@
import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)

@@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same

@@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list

@@ -1,4 +1,5 @@
import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,

@@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator):
return dim_partition_list

def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
'''

@@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)

for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)

return strategy_list

@@ -1,12 +1,12 @@
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .misc import ignore_sharding_exception
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
switch_partition_dim, update_partition_dim)

__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
'generate_sharding_size'
]
@@ -1,16 +1,19 @@
import functools
import warnings

__all__ = ['exception_handler']
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpecException

__all__ = ['ignore_sharding_exception']


def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper to handle the AssertionError in the function.
A function wrapper to handle the ShardingSpecException in the function.
If ShardingSpecException occurs, this function will return None.

Usage:
# mute the assertion error in the function
@exception_handler
@ignore_sharding_exception
def do_something():
...
"""

@@ -18,9 +21,11 @@ def exception_handler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
logger = get_dist_logger()
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
except ShardingSpecException as e:
logger.debug(e)
return None

return wrapper
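Rather than muting every AssertionError with a warning, the rewritten wrapper catches only ShardingSpecException, logs it at debug level and returns None so the caller can simply skip the strategy. A self-contained sketch of the same decorator pattern, using a stand-in exception class and the standard logging module instead of colossalai.logging.get_dist_logger:

import functools
import logging

logging.basicConfig(level=logging.DEBUG)


class IllegalShardingError(Exception):
    """Stand-in for ShardingSpecException and its subclasses."""


def ignore_sharding_exception(func):
    """Return None instead of raising when the wrapped function hits an illegal sharding spec."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger = logging.getLogger(func.__module__)
        try:
            return func(*args, **kwargs)
        except IllegalShardingError as e:
            logger.debug(e)
            return None

    return wrapper


@ignore_sharding_exception
def build_strategy(valid: bool):
    if not valid:
        raise IllegalShardingError('dimension 0 of size 6 cannot be sharded over 4 devices')
    return 'S01R = S01R x R'


assert build_strategy(True) == 'S01R = S01R x R'
assert build_strategy(False) is None    # exception logged and swallowed, strategy skipped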
@@ -1,10 +1,12 @@
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import operator

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

@@ -138,7 +140,19 @@ class _DimSpec:
return difference


class ShardingException(Exception):
class ShardingSpecException(Exception):
pass


class ShardingOutOfIndexError(ShardingSpecException):
pass


class DuplicatedShardingDimensionError(ShardingSpecException):
pass


class ShardingNotDivisibleError(ShardingSpecException):
pass

@@ -156,7 +170,11 @@ class ShardingSpec:
sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''

def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
def __init__(self,
device_mesh: DeviceMesh,
entire_shape: torch.Size,
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict

@@ -174,19 +192,36 @@ class ShardingSpec:
return ' '.join(res_list)

def _sanity_check(self):
'''
In sanity check, we need make sure all axes in logical device mesh only be used
once.
'''
dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
# make sure all axes in logical device mesh only be used once
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in self.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise ValueError(
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

# make sure that the dimension is not out of index
for dim in self.dim_partition_dict.keys():
if dim >= len(self.entire_shape):
raise ShardingOutOfIndexError(
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
)

# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in self.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
num_devices = 1

for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]

if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
)

def convert_dict_to_shard_sequence(self):
'''
Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
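With the dedicated exception hierarchy in place, _sanity_check can signal exactly which rule an illegal spec breaks: a logical mesh axis reused across dimensions, a sharded dimension that does not exist, or a dimension whose size is not divisible by the devices sharding it. A hedged, stand-alone sketch of those three checks on plain Python values (a mesh shape tuple instead of a real DeviceMesh object):

import operator
from functools import reduce


class ShardingSpecException(Exception):
    pass


class DuplicatedShardingDimensionError(ShardingSpecException):
    pass


class ShardingOutOfIndexError(ShardingSpecException):
    pass


class ShardingNotDivisibleError(ShardingSpecException):
    pass


def sanity_check(entire_shape, mesh_shape, dim_partition_dict):
    """Illustrative version of the three checks added to ShardingSpec._sanity_check."""
    # 1. every logical mesh axis may be used at most once across all tensor dimensions
    unused_mesh_axes = list(range(len(mesh_shape)))
    for dim, shard_list in dim_partition_dict.items():
        for axis in shard_list:
            if axis in unused_mesh_axes:
                unused_mesh_axes.remove(axis)
            else:
                raise DuplicatedShardingDimensionError(
                    f'invalid sharding axis {axis} for tensor dimension {dim}')

    # 2. sharded tensor dimensions must actually exist
    for dim in dim_partition_dict:
        if dim >= len(entire_shape):
            raise ShardingOutOfIndexError(
                f'dimension {dim} is out of range for a {len(entire_shape)}-dim tensor')

    # 3. each sharded dimension must be divisible by the product of device counts
    for dim, shard_list in dim_partition_dict.items():
        num_devices = reduce(operator.mul, (mesh_shape[a] for a in shard_list), 1)
        if entire_shape[dim] % num_devices != 0:
            raise ShardingNotDivisibleError(
                f'dimension {dim} has size {entire_shape[dim]}, not divisible by {num_devices} devices')


sanity_check((16, 8, 6), (2, 2), {0: [0, 1]})      # passes: 16 % (2 * 2) == 0
# sanity_check((6, 8, 6), (2, 2), {0: [0, 1]})     # would raise ShardingNotDivisibleError

Because ignore_sharding_exception now catches this exception family, a generator method that builds an illegal spec simply yields None instead of aborting strategy construction.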
@@ -1,12 +1,15 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
from cProfile import run

import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class ConvModel(nn.Module):

@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
return x


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)

@@ -1,12 +1,13 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag


class MatmulModel(nn.Module):

@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
return x


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -0,0 +1,37 @@
import torch

from colossalai.tensor.sharding_spec import ShardingSpec


def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
"""
This function checks whether the ShardingSpec is valid for the physical tensor.
This check includes 2 items:
1. the sharding spec covers all dimensions of the physical tensor
2. the sharding spec for each dimension is divisible by the number of devices.
"""
# make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
assert sharding_len == tensor_num_dim, \
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

# make sure the sharding is valid for each dim
for i in range(tensor_num_dim):
dim_size = tensor.shape[i]
dim_spec = sharding_spec.sharding_sequence[i]

if str(dim_spec).startswith('S'):
devices_str = str(dim_spec).lstrip('S')
num_devices = 1

if '0' in devices_str:
num_devices *= num_devices_in_col
if '1' in devices_str:
num_devices *= num_devices_in_row

assert dim_size >= num_devices and dim_size % num_devices == 0, \
f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
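The new test helper mirrors the divisibility rule at the tensor level: every dimension whose spec is a shard spec (S0, S1 or S01) must be at least as large as, and divisible by, the product of the mesh axes it is sharded over. A rough illustration of that per-dimension arithmetic, with plain strings standing in for the real _DimSpec objects and a hard-coded 2 x 2 mesh:

import torch


def check_dim_sharding(tensor: torch.Tensor, sharding_sequence, mesh_shape=(2, 2)):
    """Simplified stand-in for is_sharding_spec_valid: 'R' is replicated, 'S0'/'S1'/'S01' shard over mesh axes."""
    assert len(sharding_sequence) == tensor.dim(), 'sharding sequence must cover every tensor dimension'
    for i, dim_spec in enumerate(sharding_sequence):
        if not dim_spec.startswith('S'):
            continue
        num_devices = 1
        if '0' in dim_spec:
            num_devices *= mesh_shape[0]
        if '1' in dim_spec:
            num_devices *= mesh_shape[1]
        dim_size = tensor.shape[i]
        assert dim_size >= num_devices and dim_size % num_devices == 0, \
            f'dimension {i} has size {dim_size}, cannot be sharded over {num_devices} devices'


check_dim_sharding(torch.rand(2, 2, 4, 16), ['S0', 'R', 'R', 'S1'])    # ok: 2 % 2 == 0 and 16 % 2 == 0
check_dim_sharding(torch.rand(16, 8), ['S01', 'R'])                    # ok: 16 % (2 * 2) == 0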
@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
is_sharding_spec_valid


def test_linear_module_handler():
model = nn.Sequential(nn.Linear(16, 32).to('meta'))

tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)

@@ -91,6 +92,12 @@ def test_linear_module_handler():
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

# make sure the sharding spec is valid
is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))

# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]

@@ -101,7 +108,7 @@ def test_linear_module_handler():
def test_linear_function_handler():
model = nn.Linear(16, 32).to('meta')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)

@@ -117,11 +124,13 @@ def test_linear_function_handler():
# # check operation data mapping
mapping = handler.get_operation_data_mapping()

print(mapping['input'].logical_shape)

assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16])
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 16])
assert mapping['input'].logical_shape == torch.Size([16, 16])

assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta

@@ -137,7 +146,7 @@ def test_linear_function_handler():
assert mapping['output'].name == "linear"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 32])
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
assert mapping['output'].type == OperationDataType.OUTPUT

strategies_vector = handler.register_strategy(compute_resharding_cost=False)

@@ -167,11 +176,18 @@ def test_linear_function_handler():
for strategy in strategies_vector:
strategy: ShardingStrategy
print(strategy)
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

# make sure the sharding spec is valid
is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))

# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
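The two assertions added at the end of each loop encode the shape algebra of nn.Linear (output = input @ weight.T + bias): every batch-like dimension of the output must carry the same sharding as the input, and the weight's in-feature dimension (index 1 of its (out_features, in_features) layout) must follow the input's last dimension. A small, hedged restatement of that rule on plain sharding sequences:

def check_linear_sharding(input_seq, weight_seq, output_seq):
    """Toy version of the consistency assertions in the test, with strings instead of _DimSpec objects."""
    # batch-like dimensions of input and output must be sharded identically
    assert input_seq[:-1] == output_seq[:-1]
    # weight is (out_features, in_features): its dim 1 must follow the input's last dim
    assert weight_seq[1] == input_seq[-1]


# e.g. a [2, 2, 4, 16] input and a [32, 16] weight producing a [2, 2, 4, 32] output:
# shard the batch dim over mesh axis 0 and the in-feature dim over mesh axis 1
check_linear_sharding(input_seq=['S0', 'R', 'R', 'S1'],
                      weight_seq=['R', 'S1'],
                      output_seq=['S0', 'R', 'R', 'R'])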
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \

@@ -1,16 +1,18 @@
from functools import partial
from lib2to3 import pgen2
import colossalai
import torch

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
from colossalai.nn._ops._utils import gather_forward_split_backward


def run_dist(rank, world_size, port):

@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

# create mlp vars
x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()

@@ -1,6 +1,7 @@
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


def test_sharding_spec():

@@ -11,7 +12,7 @@ def test_sharding_spec():
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8, 6))
entire_shape = torch.Size((16, 8, 6))
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R