From 4973157ad7b9de399e5092b020aa7376de6b873e Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 12 Oct 2022 11:16:18 +0800 Subject: [PATCH] [autoparallel] added sharding spec conversion for linear handler (#1687) --- .../solver/op_handler/dot_handler_v2.py | 110 +++++++++++++----- .../solver/op_handler/node_handler.py | 22 +++- .../auto_parallel/solver/op_handler/utils.py | 68 +++++++++++ .../auto_parallel/solver/sharding_strategy.py | 25 +++- colossalai/tensor/sharding_spec.py | 6 + .../test_linear_handler_v2.py | 42 ++++++- 6 files changed, 226 insertions(+), 47 deletions(-) create mode 100644 colossalai/auto_parallel/solver/op_handler/utils.py diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py index 1f7fc4ad9..d18a6e88a 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py @@ -1,10 +1,13 @@ import torch import torch.nn.functional as F +from colossalai.tensor.sharding_spec import ShardingException from .node_handler import ModuleHandler, NodeHandler from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator -from typing import List, Dict +from typing import List, Dict, Union from .registry import operator_registry +from copy import deepcopy +from .utils import switch_partition_dim, update_partition_dim __all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler'] @@ -24,14 +27,22 @@ class LinearModuleHandler(ModuleHandler): def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process + input_meta_data = self.node.args[0]._meta_data + input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape physical_input_operand = OperationData(name=str(self.node.args[0]), type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + data=input_meta_data, + logical_shape=input_logical_shape) physical_other_operand = OperationData(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight'], logical_shape=self.named_parameters['weight'].shape[::-1]) - physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) + output_meta_data = self.node._meta_data + output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape + physical_output = OperationData(name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} @@ -42,28 +53,46 @@ class LinearModuleHandler(ModuleHandler): mapping['bias'] = physical_bias_operand return mapping - def post_process(self, strategy: ShardingStrategy_V2): + def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]: """ - Convert the sharding spec of the weight parameter back to its original shape. + Convert the sharding spec from the logical shape to the physical shape. 
""" + # switch the dimensions of the transposed weight for op_data, sharding_spec in strategy.input_sharding_specs.items(): if op_data.name == "weight": assert op_data.logical_shape != op_data.data.shape - dim_partition_dict = sharding_spec.dim_partition_dict - - # switch first and last dim of the linear module weight - first_dim_partition = dim_partition_dict.pop(-1, None) - last_dim_partition = dim_partition_dict.pop(0, None) - - if first_dim_partition: - dim_partition_dict[0] = first_dim_partition - - if last_dim_partition: - dim_partition_dict[-1] = last_dim_partition + switch_partition_dim(sharding_spec, 0, -1) + + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at dim 0 to one of the first few dimensions of the input + sharding_strategies = [] + input_op_data = strategy.get_op_data_by_name(str(self.node.args[0])) + output_op_data = strategy.get_op_data_by_name(str(self.node)) + num_input_dims = input_op_data.data.dim() + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + + if 0 in input_sharding_spec.dim_partition_dict: + for i in range(num_input_dims - 1): + new_strategy = strategy.clone() + input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name) + try: + update_partition_dim(sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True) + update_partition_dim(sharding_spec=output_sharding_spec, + dim_mapping={0: i}, + physical_shape=output_op_data.data.shape, + inplace=True) + sharding_strategies.append(new_strategy) + except ShardingException: + pass + else: + sharding_strategies.append(strategy) - # re-init the sharding spec - sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) - return strategy + return sharding_strategies @operator_registry.register(F.linear) @@ -118,20 +147,37 @@ class LinearFunctionHandler(NodeHandler): for op_data, sharding_spec in strategy.input_sharding_specs.items(): if op_data.name == str(self.node.args[1]): assert op_data.logical_shape != op_data.data.shape - dim_partition_dict = sharding_spec.dim_partition_dict - - # switch first and last dim of the linear module weight - first_dim_partition = dim_partition_dict.pop(-1, None) - last_dim_partition = dim_partition_dict.pop(0, None) - - if first_dim_partition: - dim_partition_dict[0] = first_dim_partition - - if last_dim_partition: - dim_partition_dict[-1] = last_dim_partition + switch_partition_dim(sharding_spec, 0, -1) + + # create multiple sharding strategies for the inputs + # as input can be multi-dimensinal and the partition dim is only 2D, + # we need to map the partition at dim 0 to one of the first few dimensions of the input + sharding_strategies = [] + input_op_data = strategy.get_op_data_by_name(str(self.node.args[0])) + output_op_data = strategy.get_op_data_by_name(str(self.node)) + num_input_dims = input_op_data.data.dim() + input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name) + + if 0 in input_sharding_spec.dim_partition_dict: + for i in range(num_input_dims - 1): + new_strategy = strategy.clone() + input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name) + output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name) + try: + update_partition_dim(sharding_spec=input_sharding_spec, + 
+                                         dim_mapping={0: i},
+                                         physical_shape=input_op_data.data.shape,
+                                         inplace=True)
+                    update_partition_dim(sharding_spec=output_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=output_op_data.data.shape,
+                                         inplace=True)
+                    sharding_strategies.append(new_strategy)
+                except ShardingException:
+                    pass
+        else:
+            sharding_strategies.append(strategy)
 
-        # re-init the sharding spec
-        sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
-        return strategy
+        return sharding_strategies
diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py
index e137f6fad..ba13591a6 100644
--- a/colossalai/auto_parallel/solver/op_handler/node_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from typing import Dict, List
+from typing import Dict, List, Union
 from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2
 
@@ -72,17 +72,27 @@ class NodeHandler(ABC):
 
         for generator in strategy_generators:
             strategies = generator.generate()
+
+            # postprocess a strategy
+            # postprocess can produce one strategy or multiple strategies
+            post_processed_strategies_map = map(self.post_process, strategies)
+            post_processed_strategies = []
+
+            for strategy in post_processed_strategies_map:
+                if isinstance(strategy, (list, tuple)):
+                    post_processed_strategies.extend(strategy)
+                else:
+                    post_processed_strategies.append(strategy)
 
             # compute the resharding costs based on the previous node
             # strategies if specified
             if compute_resharding_cost:
-                strategies = list(map(self.update_resharding_cost, strategies))
-            self.strategies_vector.extend(strategies)
+                post_processed_strategies = list(map(self.update_resharding_cost, post_processed_strategies))
+
+            self.strategies_vector.extend(post_processed_strategies)
 
-        strategies_vector = map(self.post_process, self.strategies_vector)
-        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector
 
-    def post_process(self, strategy: ShardingStrategy_V2):
+    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
         # tranform the strategy generated
         # e.g. to process the sharding strategy for the transposed weights
         return strategy
diff --git a/colossalai/auto_parallel/solver/op_handler/utils.py b/colossalai/auto_parallel/solver/op_handler/utils.py
new file mode 100644
index 000000000..59bd2f535
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/utils.py
@@ -0,0 +1,68 @@
+import torch
+from typing import Dict
+from colossalai.tensor.sharding_spec import ShardingSpec
+from copy import deepcopy
+
+
+def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+    """
+    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
+ + Args: + sharding_spec (ShardingSpec): the sharding spec for which partition dim are switched + dim1 (int): the tensor dimension to switch + dim2 (int): the tensor dimension to switch + """ + assert len(sharding_spec.entire_shape) == 2 + dim_partition_dict = sharding_spec.dim_partition_dict + dim1_partition = dim_partition_dict.pop(dim1, None) + dim2_partition = dim_partition_dict.pop(dim2, None) + + if dim1_partition: + dim_partition_dict[dim2] = dim1_partition + + if dim2_partition: + dim_partition_dict[dim1] = dim2_partition + + # re-init the sharding spec + sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) + return sharding_spec + + +def update_partition_dim(sharding_spec: ShardingSpec, + dim_mapping: Dict[int, int], + physical_shape: torch.Size, + inplace: bool = False): + """ + This method is used to update the partition dim dict from the logical one to the physical one. + + Args: + sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated + dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension + physical_shape (torch.Size): the physical shape for the tensor + """ + + if inplace: + current_sharding_spec = sharding_spec + else: + current_sharding_spec = deepcopy(sharding_spec) + + old_dim_partition_dict = current_sharding_spec.dim_partition_dict + new_dim_partition_dict = {} + + # assign new dim + for old_dim, new_dim in dim_mapping.items(): + mesh_dims = old_dim_partition_dict.pop(old_dim) + new_dim_partition_dict[new_dim] = mesh_dims + + for tensor_dim, mesh_dims in old_dim_partition_dict.items(): + if tensor_dim in new_dim_partition_dict: + raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}") + else: + new_dim_partition_dict[tensor_dim] = mesh_dims + + # update sharding spec + current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=new_dim_partition_dict) + return current_sharding_spec diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index d101589b2..64e0ea779 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -1,3 +1,4 @@ +from copy import deepcopy from dataclasses import dataclass from abc import ABC, abstractmethod from enum import Enum @@ -121,16 +122,12 @@ class ShardingStrategy_V2: communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None) memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. 
- input_resharding_costs (Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list - with j-th strategy in its strategies_vector transforms to sharding spec wanted in this - strategy.(default to None) """ name: str sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None compute_cost: TrainCycleItem = None communication_cost: TrainCycleItem = None memory_cost: TrainCycleItem = None - input_resharding_costs: Dict[OperationData, List[float]] = None communication_actions: Dict[OperationData, CommSpec] = None resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None @@ -169,6 +166,26 @@ class ShardingStrategy_V2: return sharding_spec raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") + def clone(self): + + def _deepcopy_dict_vals(data: Dict): + return {k: deepcopy(v) for k, v in data.items()} + + sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None + communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None + resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None + compute_cost = deepcopy(self.compute_cost) + communication_cost = deepcopy(self.communication_cost) + memory_cost = deepcopy(self.memory_cost) + + return ShardingStrategy_V2(name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs) + class StrategiesVector(list): ''' diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 50083ca5b..fe33baf65 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -6,6 +6,8 @@ from enum import Enum from functools import reduce import operator +__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] + ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 @@ -136,6 +138,10 @@ class _DimSpec: return difference +class ShardingException(Exception): + pass + + class ShardingSpec: ''' Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong diff --git a/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py index 993930060..7ef8b9e68 100644 --- a/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py +++ b/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py @@ -3,14 +3,15 @@ import torch import torch.nn as nn from colossalai.fx import ColoTracer, ColoGraphModule from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler -from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy_V2 from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec def test_linear_module_handler(): model = nn.Sequential(nn.Linear(16, 32).to('meta')) tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')}) + graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')}) gm = ColoGraphModule(model, graph) physical_mesh_id = 
torch.arange(0, 4) @@ -34,9 +35,9 @@ def test_linear_module_handler(): assert mapping['input'].name == "input_1" assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16]) + assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16]) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16]) + assert mapping['input'].logical_shape == torch.Size([16, 16]) assert mapping['other'].name == "weight" assert mapping['other'].data.is_meta @@ -52,11 +53,14 @@ def test_linear_module_handler(): assert mapping['output'].name == "_0" assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 32]) + assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32]) assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping['output'].logical_shape == torch.Size([16, 32]) strategies_vector = handler.register_strategy() strategy_name_list = [val.name for val in strategies_vector] + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) > 8 # SS = SR x RS assert 'S0S1 = S0R x RS1' in strategy_name_list @@ -78,6 +82,19 @@ def test_linear_module_handler(): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + for strategy in strategies_vector: + strategy: ShardingStrategy_V2 + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + def test_linear_function_handler(): model = nn.Linear(16, 32).to('meta') @@ -123,6 +140,8 @@ def test_linear_function_handler(): strategies_vector = handler.register_strategy() strategy_name_list = [val.name for val in strategies_vector] + # one strategy will be converted to different physical sharding spec + assert len(strategy_name_list) > 8 # SS = SR x RS assert 'S0S1 = S0R x RS1' in strategy_name_list @@ -144,6 +163,19 @@ def test_linear_function_handler(): assert 'RS0 = RR x RS0' in strategy_name_list assert 'RS1 = RR x RS1' in strategy_name_list + for strategy in strategies_vector: + strategy: ShardingStrategy_V2 + input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') + weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') + bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + + # make sure the sharding matches across different operation data + assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] + assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] + assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] + assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] + if __name__ == '__main__': 
test_linear_module_handler()
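
Note: the logical-to-physical conversion added in post_process is easiest to follow with a concrete shape. The sketch below is illustrative only and is not part of the patch; it mirrors the dim-mapping idea behind update_partition_dim using plain dicts instead of ShardingSpec objects, assuming a physical input of shape (2, 2, 4, 16) whose logical 2D view is (16, 16) and a 2x2 device mesh. The helper name map_logical_to_physical and the variable names are made up for this sketch.

    from typing import Dict, List, Tuple


    def map_logical_to_physical(logical_partition: Dict[int, List[int]],
                                physical_shape: Tuple[int, ...],
                                dim_mapping: Dict[int, int]) -> Dict[int, List[int]]:
        """Move mesh dims sharded on a logical tensor dim onto the chosen physical tensor dim."""
        remaining = dict(logical_partition)
        physical_partition: Dict[int, List[int]] = {}

        # dims named in dim_mapping move from their logical index to the given physical index
        for logical_dim, physical_dim in dim_mapping.items():
            if logical_dim in remaining:
                physical_partition[physical_dim] = remaining.pop(logical_dim)

        # the remaining dims keep their position; negative indices are written out
        # against the physical rank so the result is unambiguous
        for dim, mesh_dims in remaining.items():
            physical_dim = dim if dim >= 0 else len(physical_shape) + dim
            if physical_dim in physical_partition:
                raise KeyError(f"duplicated entry for tensor dim {physical_dim}")
            physical_partition[physical_dim] = mesh_dims
        return physical_partition


    # 'S0S1' on the logical (16, 16) view of a (2, 2, 4, 16) input:
    # dim 0 is sharded over mesh dim 0, the trailing dim over mesh dim 1
    logical_partition = {0: [0], -1: [1]}
    physical_shape = (2, 2, 4, 16)

    # the handler tries every non-feature dim of the physical input for the dim-0 shard,
    # which is why one logical strategy can expand into several physical strategies
    for i in range(len(physical_shape) - 1):
        print(i, map_logical_to_physical(logical_partition, physical_shape, {0: i}))
    # 0 {0: [0], 3: [1]}
    # 1 {1: [0], 3: [1]}
    # 2 {2: [0], 3: [1]}

This expansion is also why the module handler test above asserts len(strategy_name_list) > 8: every logical strategy that shards dim 0 yields one physical strategy per non-feature input dimension, minus any mapping rejected with ShardingException.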