From f6c6a932b8e0f921b2ee9623d2969f30ed614077 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Sun, 9 Oct 2022 14:49:18 +0800
Subject: [PATCH] [autoparallel] add following node generator (#1673)

* [autoparallel] add following node generator

* polish code

* polish code

* update name of arguments
---
 .../solver/op_handler/getitem_handler.py      |  39 +++++
 .../auto_parallel/solver/sharding_strategy.py |  13 +-
 .../auto_parallel/solver/strategy/__init__.py |   4 +-
 .../solver/strategy/batch_norm_generator.py   |  11 +-
 .../strategy/conv_strategy_generator.py       |   6 +-
 .../solver/strategy/getitem_generator.py      | 147 ++++++++++++++++++
 .../solver/strategy/strategy_generator.py     |  20 ++-
 .../test_node_handler/test_conv_handler_v2.py |   2 +-
 .../test_node_handler/test_getitem_handler.py |  85 ++++++++++
 9 files changed, 310 insertions(+), 17 deletions(-)
 create mode 100644 colossalai/auto_parallel/solver/op_handler/getitem_handler.py
 create mode 100644 colossalai/auto_parallel/solver/strategy/getitem_generator.py
 create mode 100644 tests/test_auto_parallel/test_node_handler/test_getitem_handler.py

diff --git a/colossalai/auto_parallel/solver/op_handler/getitem_handler.py b/colossalai/auto_parallel/solver/op_handler/getitem_handler.py
new file mode 100644
index 000000000..71022ccdd
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/getitem_handler.py
@@ -0,0 +1,39 @@
+import torch
+from .node_handler import NodeHandler
+from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
+from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator_V2
+from typing import List, Dict
+from .registry import operator_registry
+import operator
+
+__all__ = ['GetItemHandler']
+
+
+@operator_registry.register(operator.getitem)
+class GetItemHandler(NodeHandler):
+    """
+    A GetItemHandler which deals with the sharding strategies of operator.getitem.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        if isinstance(op_data_mapping["input"].data, torch.Tensor):
+            generators.append(TensorStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
+        else:
+            generators.append(TensorTupleStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
+
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # getitem consumes the output of its predecessor node directly,
+        # so the input operand reuses the predecessor's meta data as-is
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+        physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1])
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "index": physical_other_operand, "output": physical_output}
+
+        return mapping
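A minimal, dependency-free sketch of the mapping this handler builds and the dispatch it performs; plain dicts stand in for `OperationData`, and the node names and shapes are illustrative rather than taken from a real trace:

```python
import operator
import torch

# stand-in for the predecessor node's meta data, e.g. a conv2d output
input_meta = torch.empty(4, 4, 62, 62, device='meta')
index = 1
output_meta = operator.getitem(input_meta, index)    # same as input_meta[1]

# mirrors get_operation_data_mapping: input/index/output records keyed by role
mapping = {
    "input": {"name": "conv2d", "type": "ARG", "data": input_meta},
    "index": {"name": "index", "type": "ARG", "data": index},
    "output": {"name": "getitem", "type": "OUTPUT", "data": output_meta},
}

# a torch.Tensor input selects TensorStrategyGenerator; a tuple of tensors
# would select TensorTupleStrategyGenerator instead
assert isinstance(mapping["input"]["data"], torch.Tensor)
assert output_meta.shape == torch.Size([4, 62, 62])
```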
+ """ + + def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + if isinstance(op_data_mapping["input"].data, torch.Tensor): + generators.append(TensorStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + else: + generators.append(TensorTupleStrategyGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + physical_input_operand = OperationData(name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data) + physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) + physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + + mapping = {"input": physical_input_operand, "index": physical_other_operand, "output": physical_output} + + return mapping diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index b81c25ffd..d101589b2 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -63,24 +63,27 @@ class OperationData: Args: name (str): the name of the operation-related data type (OperationDataType): the type of the operation data - data (torch.Tensor): the value for this data, usually it is a meta tensor. + data (Any): the value for this data, usually it is a meta tensor. logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. 
""" name: str type: OperationDataType - data: torch.Tensor + data: Any logical_shape: Tuple[int] = None def __post_init__(self): # if no logical shape is specified, use the data shape as the logical shape - if self.logical_shape is None: + if self.logical_shape is None and isinstance(self.data, torch.Tensor): self.logical_shape = self.data.shape def __repr__(self) -> str: return f'OperationData(name={self.name}, type={self.type})' + def __eq__(self, other) -> bool: + return other.name == self.name + def __hash__(self) -> int: - return hash(f'{self.name}-{self.type}') + return hash(f'{self.name}') @dataclass @@ -123,7 +126,7 @@ class ShardingStrategy_V2: strategy.(default to None) """ name: str - sharding_specs: Dict[OperationData, ShardingSpec] = None + sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None compute_cost: TrainCycleItem = None communication_cost: TrainCycleItem = None memory_cost: TrainCycleItem = None diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py index 823a472f8..9881e0512 100644 --- a/colossalai/auto_parallel/solver/strategy/__init__.py +++ b/colossalai/auto_parallel/solver/strategy/__init__.py @@ -2,10 +2,12 @@ from .strategy_generator import StrategyGenerator_V2 from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator from .conv_strategy_generator import ConvStrategyGenerator from .batch_norm_generator import BatchNormStrategyGenerator +from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator from .layer_norm_generator import LayerNormGenerator __all__ = [ 'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', - 'BatchNormStrategyGenerator', 'LayerNormGenerator' + 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', + 'LayerNormGenerator' ] diff --git a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py index 8e9a16c55..a89517004 100644 --- a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py @@ -86,12 +86,12 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2): # compute bwd cost incurred # bwd_cost = input_grad + other_grad + bias_grad bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) - bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) - bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_activation_cost) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_activation_cost) + parameter=fwd_parameter_cost + bwd_parameter_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -288,4 +288,9 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2): # S01R = S01R x R WITH SYNC_BN 
diff --git a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py
index bddfc6b65..ef989f92c 100644
--- a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py
@@ -91,12 +91,12 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
         # compute bwd cost incurred
         # bwd_cost = input_grad + other_grad + bias_grad
         bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
-        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
-        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_activation_cost)
+        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

         # compute total cost
         total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
-                                    parameter=fwd_parameter_cost + bwd_activation_cost)
+                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
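Before the new generator file, it helps to pin down the three `getitem` cases its docstring lists; this runnable snippet reproduces them with meta tensors (the shapes are illustrative):

```python
import torch

t = torch.empty(4, 16, 62, 62, device='meta')

case1 = t[1]          # case 1: Tensor indexed by an int  -> TensorStrategyGenerator
case2 = t[0:2]        # case 2: Tensor indexed by a slice -> TensorStrategyGenerator
case3 = (t, t)[1]     # case 3: tuple indexed by an int   -> TensorTupleStrategyGenerator

assert case1.shape == torch.Size([16, 62, 62])      # dim 0 is removed
assert case2.shape == torch.Size([2, 16, 62, 62])   # dim 0 shrinks but survives
assert case3 is t                                   # no tensor op at all, just selection
```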
diff --git a/colossalai/auto_parallel/solver/strategy/getitem_generator.py b/colossalai/auto_parallel/solver/strategy/getitem_generator.py
new file mode 100644
index 000000000..43f2eb550
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/getitem_generator.py
@@ -0,0 +1,147 @@
+import operator
+from functools import reduce
+from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from .strategy_generator import FollowingStrategyGenerator
+from typing import List
+from .._utils import exception_handler
+import copy
+
+__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator']
+
+
+class GetItemStrategyGenerator(FollowingStrategyGenerator):
+    """
+    GetItemStrategyGenerator is a generic class to generate strategies for operator.getitem.
+    The operation data is defined as `output = input[other]`.
+
+    There are mainly three use cases:
+        1. args_0._meta_data: torch.Tensor, args_1._meta_data: int
+        2. args_0._meta_data: torch.Tensor, args_1._meta_data: slice
+        3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
+    """
+
+    @property
+    def has_bias(self):
+        return 'bias' in self.op_data
+
+    def validate(self) -> bool:
+        return super().validate()
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+        return TrainCycleItem(fwd=10, bwd=10, total=20)    # getitem is compute-trivial, so a small constant cost is used
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+        '''
+        Compute the memory cost per device with this specific strategy.
+        '''
+        forward_size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = input + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
+        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
+                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+        return super().update_memory_cost(strategy)
+
+
+class TensorStrategyGenerator(GetItemStrategyGenerator):
+    '''
+    Deal with cases 1 and 2.
+    '''
+
+    def generate(self):
+        strategy_list = []
+        for strategy in self.predecessor_node.strategies_vector:
+            dim_partition_dict_mapping = {}
+            communication_action_mapping = {}
+            dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
+            dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
+            gather_input = 0 in dim_partition_dict_for_input
+            if gather_input:
+                logical_process_axis = dim_partition_dict_for_output.pop(0)
+
+            shift_dim_partition_dict_for_output = {}
+            for dim, mesh_dim_list in dim_partition_dict_for_output.items():
+                shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list
+            dim_partition_dict_for_output = shift_dim_partition_dict_for_output
+            dim_partition_dict_mapping = {
+                "input": dim_partition_dict_for_input,
+                "output": dim_partition_dict_for_output,
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+            if gather_input:
+                input_communication_spec = self.get_communication_spec(
+                    sharding_spec_mapping["input"],
+                    communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+                    logical_process_axis=logical_process_axis)
+                communication_action_mapping["input"] = input_communication_spec
+
+            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+
+            strategy = self.get_sharding_strategy(name=name,
+                                                  sharding_spec_mapping=sharding_spec_mapping,
+                                                  communication_action_mapping=communication_action_mapping)
+
+            strategy_list.append(strategy)
+
+        for strategy in strategy_list:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategy_list
+
+
+class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
+    '''
+    Deal with case 3.
+    '''
+
+    def generate(self):
+        strategy_list = []
+        index = self.op_data["index"].data
+
+        for strategy in self.predecessor_node.strategies_vector:
+            # the sharding spec for input in this case is a tuple of ShardingSpec
+            sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
+            dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
+            dim_partition_dict_mapping = {}
+            communication_action_mapping = {}
+            dim_partition_dict_mapping = {
+                "output": dim_partition_dict_for_output,
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+            sharding_spec_mapping["input"] = sharding_spec_for_input
+
+            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_for_input[index].sharding_sequence}'
+
+            strategy = self.get_sharding_strategy(name=name,
+                                                  sharding_spec_mapping=sharding_spec_mapping,
+                                                  communication_action_mapping=communication_action_mapping)
+
+            strategy_list.append(strategy)
+
+        for strategy in strategy_list:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategy_list
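The dimension bookkeeping in `TensorStrategyGenerator.generate` is the subtle part: if dim 0 of the input is sharded it must be gathered first, because an integer index consumes that dim, and every surviving sharded dim shifts down by one in the output. A plain-dict walk-through of that logic:

```python
# input sharded as S0 on dim 0 and S1 on dim 2, e.g. sharding sequence S0,R,S1,R
dim_partition_dict_for_input = {0: [0], 2: [1]}

dim_partition_dict_for_output = dict(dim_partition_dict_for_input)
gather_input = 0 in dim_partition_dict_for_output
if gather_input:
    # dim 0 disappears after `tensor[i]`, so it is gathered along mesh axis 0
    logical_process_axis = dim_partition_dict_for_output.pop(0)

# surviving dims move down by one: old dim 2 becomes output dim 1
dim_partition_dict_for_output = {dim - 1: mesh for dim, mesh in dim_partition_dict_for_output.items()}

assert gather_input and logical_process_axis == [0]
assert dim_partition_dict_for_output == {1: [1]}
```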
diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/solver/strategy/strategy_generator.py
index 0cdc8b018..d44d86ad3 100644
--- a/colossalai/auto_parallel/solver/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/solver/strategy/strategy_generator.py
@@ -8,6 +8,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
 from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType
+from torch.fx import Node
+import copy


 class StrategyGenerator_V2(ABC):
@@ -72,10 +74,6 @@ class StrategyGenerator_V2(ABC):
         """
         A factory method to produce a CommSpec object.
         """
-        # use flatten device mesh the same action is applied to two axes
-        if isinstance(logical_process_axis, list) and len(logical_process_axis) == 2:
-            sharding_spec.device_mesh = sharding_spec.device_mesh.flatten()
-            logical_process_axis = 0
         return CommSpec(comm_pattern=communication_pattern,
                         sharding_spec=sharding_spec,
                         logical_process_axis=logical_process_axis)
@@ -150,3 +148,17 @@ class StrategyGenerator_V2(ABC):
         If True, means this generator can be used for the current operation.
         """
         pass
+
+
+class FollowingStrategyGenerator(StrategyGenerator_V2):
+    """
+    FollowingStrategyGenerator is used to generate the sharding strategies which depend on their predecessor node.
+
+    TODO: remove the original strategy_generator.py after refactoring
+    """
+
+    def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
+                 predecessor_node: Node):
+        self.op_data = operation_data_mapping
+        self.device_mesh = device_mesh
+        self.predecessor_node = predecessor_node
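`FollowingStrategyGenerator` itself is only a constructor contract; what makes a follower a follower is that its `generate` iterates over `predecessor_node.strategies_vector` and emits one strategy per parent strategy. A self-contained sketch of that pattern, with stub types standing in for `OperationData`, `DeviceMesh`, and `torch.fx.Node`:

```python
from typing import Any, Dict, List

class FollowerSketch:
    """Mimics FollowingStrategyGenerator.__init__ plus a trivial generate()."""

    def __init__(self, operation_data_mapping: Dict[str, Any], device_mesh: Any, predecessor_node: Any):
        self.op_data = operation_data_mapping
        self.device_mesh = device_mesh
        self.predecessor_node = predecessor_node

    def generate(self) -> List[str]:
        # one derived strategy per predecessor strategy, which is why the test
        # below asserts len(getitem_strategies_vector) == len(conv_strategies_vector)
        return [f'follow({parent})' for parent in self.predecessor_node.strategies_vector]

class FakeNode:
    strategies_vector = ['S0R = S0R x RR', 'RS1 = RS1 x RR']

assert len(FollowerSketch({}, None, FakeNode()).generate()) == 2
```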
diff --git a/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py
index c974fd34e..8fb21e91d 100644
--- a/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py
+++ b/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py
@@ -165,7 +165,7 @@ def test_conv_function_handler():
     assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
     assert mapping['output'].type == OperationDataType.OUTPUT

-    strategies_vector = handler.register_strategy()
+    handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]

     # SS = SR x RS
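The one-line test change above relies on `register_strategy` filling the `strategies_vector` that was handed to the handler at construction, so the local rebinding was redundant; this is inferred from the surrounding test code rather than from documented API behavior. A toy model of that contract:

```python
class HandlerSketch:
    def __init__(self, strategies_vector):
        self.strategies_vector = strategies_vector

    def register_strategy(self):
        # mutate the caller-owned vector in place
        self.strategies_vector.extend(['S0R', 'RS1'])

strategies_vector = []
HandlerSketch(strategies_vector).register_strategy()
assert strategies_vector == ['S0R', 'RS1']    # caller's list already holds them
```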
diff --git a/tests/test_auto_parallel/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_node_handler/test_getitem_handler.py
new file mode 100644
index 000000000..c4ef16fc1
--- /dev/null
+++ b/tests/test_auto_parallel/test_node_handler/test_getitem_handler.py
@@ -0,0 +1,85 @@
+from colossalai.fx.tracer.meta_patch.patched_module import linear
+import torch
+import torch.nn as nn
+from colossalai.fx import ColoTracer, ColoGraphModule
+from colossalai.auto_parallel.solver.op_handler.getitem_handler import GetItemHandler
+from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class GetItemModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, other):
+        conv_node = nn.functional.conv2d(input, other)
+        x = conv_node[1]
+        return x
+
+
+def test_getitem_function_handler():
+    model = GetItemModel()
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
+    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
+    #     %getitem : [#users=1] = call_function[target=operator.getitem](args = (%conv2d, 1), kwargs = {})
+    #     return getitem
+    graph = tracer.trace(model,
+                         meta_args={
+                             "input": torch.rand(4, 4, 64, 64).to('meta'),
+                             "other": torch.rand(4, 4, 3, 3).to('meta'),
+                         })
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    conv_mod_node = list(graph.nodes)[2]
+    getitem_mod_node = list(graph.nodes)[3]
+    getitem_strategies_vector = StrategiesVector(getitem_mod_node)
+    conv_strategies_vector = StrategiesVector(conv_mod_node)
+
+    # build handler
+    conv_handler = ConvFunctionHandler(node=conv_mod_node,
+                                       device_mesh=device_mesh,
+                                       strategies_vector=conv_strategies_vector)
+    conv_handler.register_strategy()
+    setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
+    getitem_handler = GetItemHandler(node=getitem_mod_node,
+                                     device_mesh=device_mesh,
+                                     strategies_vector=getitem_strategies_vector)
+
+    getitem_handler.register_strategy()
+    # check operation data mapping
+    mapping = getitem_handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "conv2d"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])
+
+    assert mapping['index'].name == "index"
+    assert isinstance(mapping['index'].data, int)
+    assert mapping['index'].type == OperationDataType.ARG
+
+    assert mapping['output'].name == "getitem"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 62, 62])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    # getitem is a following strategy handler, so its number of strategies equals that of the predecessor node
+    assert len(getitem_strategies_vector) == len(conv_strategies_vector)
+
+
+if __name__ == '__main__':
+    test_getitem_function_handler()
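The asserted shapes follow from ordinary convolution arithmetic; a quick check of the numbers the test expects (meta tensors only, no real kernels run):

```python
import torch

# stride-1, padding-0, 3x3 conv: 64 -> (64 - 3) + 1 = 62 on each spatial dim
assert (64 - 3) // 1 + 1 == 62

# indexing the conv output with 1 drops its leading dimension
conv_out = torch.empty(4, 4, 62, 62, device='meta')
assert conv_out[1].shape == torch.Size([4, 62, 62])
```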