mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add shard option (#2423)
parent 1b7587d958
commit 41429b9b28
@@ -11,6 +11,7 @@ from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
+from .option import ShardOption
 from .output_handler import OutputHandler
 from .placeholder_handler import PlaceholderHandler
 from .registry import operator_registry
@@ -27,5 +28,5 @@ __all__ = [
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
 ]
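With the package __init__.py change above, the new enum should be importable straight from the node handler package. A one-line sketch, assuming a ColossalAI build that includes this change:

from colossalai.auto_parallel.tensor_shard.node_handler import ShardOption    # enabled by the import/__all__ update above

print(list(ShardOption))    # [<ShardOption.STANDARD: 0>, <ShardOption.SHARD: 1>, <ShardOption.FULL_SHARD: 2>]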
@@ -5,6 +5,7 @@ import torch
 from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -35,12 +36,14 @@ class NodeHandler(ABC):
         node: Node,
         device_mesh: DeviceMesh,
         strategies_vector: StrategiesVector,
+        shard_option: ShardOption = ShardOption.STANDARD,
     ) -> None:
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
+        self.shard_option = shard_option

     def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
         """
@@ -181,6 +184,21 @@ class NodeHandler(ABC):
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                     check_sharding_spec_validity(sharding_spec, op_data.data)

+        remove_strategy_list = []
+        for strategy in self.strategies_vector:
+            shard_level = 0
+            for op_data, sharding_spec in strategy.sharding_specs.items():
+                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
+                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
+                        shard_level += len(shard_axis)
+            if self.shard_option == ShardOption.SHARD and shard_level == 0:
+                remove_strategy_list.append(strategy)
+            if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
+                remove_strategy_list.append(strategy)
+
+        for strategy in remove_strategy_list:
+            self.strategies_vector.remove(strategy)
+
         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
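The pruning rule added to register_strategy above can be restated on its own: shard_level is the total number of device-mesh axes a strategy uses across all of its operation data, and the chosen option decides how small that total may be before the strategy is dropped. Below is a minimal sketch of that rule; shard_level and is_kept are illustrative helpers (not library API), the dim_partition_dict values are made-up examples, and only the ShardOption import comes from this commit.

from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption


def shard_level(dim_partition_dicts):
    # Total number of mesh axes used across all operands,
    # mirroring the counting loop added to register_strategy above.
    return sum(len(axes) for spec in dim_partition_dicts for axes in spec.values())


def is_kept(option, dim_partition_dicts):
    level = shard_level(dim_partition_dicts)
    if option == ShardOption.SHARD and level == 0:
        return False
    if option == ShardOption.FULL_SHARD and level <= 1:
        return False
    return True


replicated = [{}, {}, {}]                   # no operand is sharded        -> level 0
one_axis = [{0: [0]}, {}, {}]               # one operand, one mesh axis   -> level 1
both_axes = [{0: [0], 1: [1]}, {}, {}]      # one operand, both mesh axes  -> level 2

print(is_kept(ShardOption.STANDARD, replicated))    # True: STANDARD never prunes
print(is_kept(ShardOption.SHARD, replicated))       # False: level 0 is pruned
print(is_kept(ShardOption.FULL_SHARD, one_axis))    # False: level <= 1 is pruned
print(is_kept(ShardOption.FULL_SHARD, both_axes))   # True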
@@ -0,0 +1,17 @@
+from enum import Enum
+
+__all__ = ['ShardOption']
+
+
+class ShardOption(Enum):
+    """
+    This enum class defines the shard level required in node strategies.
+
+    Notes:
+        STANDARD: We do not add any extra shard requirements.
+        SHARD: We require the node to be sharded using at least one device mesh axis.
+        FULL_SHARD: We require the node to be sharded using all device mesh axes.
+    """
+    STANDARD = 0
+    SHARD = 1
+    FULL_SHARD = 2
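For context, a compact usage sketch of the enum together with the new constructor keyword. It assumes node, device_mesh and strategies_vector have already been built from a traced graph; the full, runnable setup is exactly what the test below does.

from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, ShardOption

# FULL_SHARD: register_strategy (see above) prunes strategies whose total shard level is <= 1.
handler = LinearFunctionHandler(node=node,
                                device_mesh=device_mesh,
                                strategies_vector=strategies_vector,
                                shard_option=ShardOption.FULL_SHARD)
strategies = handler.register_strategy(compute_resharding_cost=False)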
@@ -0,0 +1,112 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize
+
+
+class LinearModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, others, bias=None):
+        x = nn.functional.linear(input, others, bias=bias)
+        return x
+
+
+def check_shard_option(shard_option):
+    model = LinearModel().cuda()
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+    tracer = ColoTracer()
+    graph = tracer.trace(model,
+                         meta_args={
+                             "input": torch.rand(4, 4, 4, 16).to('meta'),
+                             'others': torch.rand(32, 16).to('meta')
+                         })
+    gm = ColoGraphModule(model, graph)
+    linear_func_node = list(graph.nodes)[2]
+    strategies_vector = StrategiesVector(linear_func_node)
+
+    # build handler
+    handler = LinearFunctionHandler(node=linear_func_node,
+                                    device_mesh=device_mesh,
+                                    strategies_vector=strategies_vector,
+                                    shard_option=shard_option)
+
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_1' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_2' in strategy_name_list
+    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0_2' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R_1' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_2' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_1' in strategy_name_list
+    assert 'S1R = S1S0 x S0R_2' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # S01R = S01R x RR
+    assert 'S01R = S01R x RR_0' in strategy_name_list
+    assert 'S01R = S01R x RR_1' in strategy_name_list
+    assert 'S01R = S01R x RR_2' in strategy_name_list
+
+    # RR = RS01 x S01R
+    assert 'RR = RS01 x S01R' in strategy_name_list
+
+    # RS01 = RR x RS01
+    assert 'RS01 = RR x RS01' in strategy_name_list
+
+    if shard_option == ShardOption.SHARD:
+        # RR = RS x SR
+        assert 'RR = RS0 x S0R' in strategy_name_list
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS0 = RR x RS0' in strategy_name_list
+        assert 'RS1 = RR x RS1' in strategy_name_list
+
+    if shard_option == ShardOption.STANDARD:
+        # RR = RS x SR
+        assert 'RR = RS0 x S0R' in strategy_name_list
+        assert 'RR = RS1 x S1R' in strategy_name_list
+
+        # RS = RR x RS
+        assert 'RS0 = RR x RS0' in strategy_name_list
+        assert 'RS1 = RR x RS1' in strategy_name_list
+
+        # RR = RR x RR
+        assert 'RR = RR x RR' in strategy_name_list
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+def test_shard_option():
+    for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
+        check_shard_option(shard_option)
+
+
+if __name__ == '__main__':
+    test_shard_option()