diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
index b678c59a5..8463cc62b 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
 
-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -68,7 +68,7 @@ class ConvModuleHandler(ModuleHandler):
                     dim_partition_dict[1] = second_dim_partition
 
                 # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
         return strategy
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 299184b29..8d9683766 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -46,6 +46,7 @@ class NodeHandler(ABC):
         # TODO: test this function when other handlers are ready
         resharding_costs = {}
         shape_consistency_manager = ShapeConsistencyManager()
+
         for node in self.predecessor_node:
             node_name = str(node)
 
@@ -54,7 +55,9 @@ class NodeHandler(ABC):
             assert hasattr(node, 'strategies_vector'), \
                 f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
             prev_strategy_vector = node.strategies_vector
-            prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+            prev_sharding_specs = [
+                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
+            ]
 
             # get the current sharding spec generated by this node handler
             op_data = strategy.get_op_data_by_name(node_name)
diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
index e48f0d4e5..a9e7109f2 100644
--- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
+++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
@@ -1,15 +1,17 @@
-from ast import NodeTransformer
-import torch
-from typing import List
-from torch.fx import symbolic_trace
-from torch.fx.node import Node
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 import builtins
 import operator
+from ast import NodeTransformer
 from copy import deepcopy
+from typing import List
+
+import torch
+from torch.fx import symbolic_trace
+from torch.fx.node import Node
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 
 shape_consistency_manager = ShapeConsistencyManager()
 
diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index 557efda8a..fa1f663a0 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -1,16 +1,19 @@
-import torch
-from dataclasses import dataclass
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
-from enum import Enum
-from copy import deepcopy
-from typing import Dict, List, Optional, Tuple, Union
-from colossalai.context.singleton_meta import SingletonMeta
-import torch.distributed as dist
 import math
-from functools import reduce
 import operator
+from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException, _DimSpec
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+
 from .comm_spec import *
 
 __all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options']
@@ -62,10 +65,10 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
     def get_all_all_gather_spec(self, source_spec, orig_cost_dict):
         '''
-        Get all valid sharding specs from source_spec with single all-gather operation, and 
+        Get all valid sharding specs from source_spec with single all-gather operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         For the all-gather operation, we just care about the S dimension.
-        
+
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -82,12 +85,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             shape_consistency_manager = ShapeConsistencyManager()
             rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
             print(rst_dict)
-        
+
         Output:
-            {DistSpec: 
-                    shard_sequence: R,S1,R 
-                    device_mesh_shape: (4, 4): 0, DistSpec: 
-                    shard_sequence: S0,R,R 
+            {DistSpec:
+                    shard_sequence: R,S1,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,R,R
                     device_mesh_shape: (4, 4): 0}
         '''
         valid_spec_dict = {}
@@ -120,20 +123,23 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 cost_dict = comm_spec.get_comm_cost()
 
                 # generate new sharding spec
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
-                for phase, cost in cost_dict.items():
-                    cost_dict[phase] = cost + orig_cost_dict[phase]
-                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
         return valid_spec_dict
 
     def get_all_all_to_all_spec(self, source_spec, orig_cost_dict):
         '''
-        Get all valid sharding specs from source_spec with single all-to-all operation, and 
+        Get all valid sharding specs from source_spec with single all-to-all operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         For the all-to-all operation, we just care about the pairs containing S dimension.
-        
+
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -150,14 +156,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             shape_consistency_manager = ShapeConsistencyManager()
             rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
             print(rst_dict)
-        
+
         Output:
-            {DistSpec: 
-                    shard_sequence: S01,R,R 
-                    device_mesh_shape: (4, 4): 0, DistSpec: 
-                    shard_sequence: R,S1,S0 
-                    device_mesh_shape: (4, 4): 0, DistSpec: 
-                    shard_sequence: S0,R,S1 
+            {DistSpec:
+                    shard_sequence: S01,R,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: R,S1,S0
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,R,S1
                     device_mesh_shape: (4, 4): 0}
         '''
         valid_spec_dict = {}
@@ -223,20 +229,24 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                     new_dim_partition_dict.pop(b_index)
 
                 # generate new sharding spec
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
-                for phase, cost in cost_dict.items():
-                    cost_dict[phase] = cost + orig_cost_dict[phase]
-                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
+
         return valid_spec_dict
 
     def get_all_shard_spec(self, source_spec, orig_cost_dict):
         '''
-        Get all valid sharding specs from source_spec with single shard operation, and 
+        Get all valid sharding specs from source_spec with single shard operation, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         For the sharding operation, we just care about legal sharding dimensions.
-        
+
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -253,14 +263,14 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             shape_consistency_manager = ShapeConsistencyManager()
             rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
             print(rst_dict)
-        
+
         Output:
-            {DistSpec: 
-                    shard_sequence: S01,R,R 
-                    device_mesh_shape: (4, 4): 0, DistSpec: 
-                    shard_sequence: S0,S1,R 
-                    device_mesh_shape: (4, 4): 0, DistSpec: 
-                    shard_sequence: S0,R,S1 
+            {DistSpec:
+                    shard_sequence: S01,R,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,S1,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,R,S1
                     device_mesh_shape: (4, 4): 0}
         '''
         valid_spec_dict = {}
@@ -275,6 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             return valid_spec_dict
 
         tensor_dims = len(source_spec.entire_shape)
+
        for index in range(tensor_dims):
            if index not in source_spec.dim_partition_dict:
                shard_list_list = shard_simulator((index, []), legal_sharding_dims)
@@ -300,23 +311,26 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 cost_dict = comm_spec.get_comm_cost()
 
                 # generate new sharding spec
-                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
-                                                 source_spec.entire_shape,
-                                                 dim_partition_dict=new_dim_partition_dict)
-                for phase, cost in cost_dict.items():
-                    cost_dict[phase] = cost + orig_cost_dict[phase]
-                valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
         return valid_spec_dict
 
     def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict):
         '''
-        Get all valid sharding specs from source_spec with one step transform, and 
+        Get all valid sharding specs from source_spec with one step transform, and
         accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
         Note: all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
         and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, we
         could safely put them together.
-        
+
         Argument:
             source_spec(ShardingSpec): the ShardingSpec of the source_spec.
             orig_cost(float): the original communication cost before this operation.
@@ -343,7 +357,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             Repeat above steps until the source spec transform to target spec.
 
         During finding the transform path, commucation cost will be accumulated, and it
-        will be finally used in auto parallel solver. 
+        will be finally used in auto parallel solver.
 
         Additionally, to avoid repeating the path search in runtime, we cached all solved path in auto parallel
         strategy building time, which could handle most of cases in runtime.
@@ -361,30 +375,30 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         Example:
             dim_partition_source = {1: [0, 1]}
             dim_partition_target = {0: [0, 1]}
-            # DistSpec: 
-            #     shard_sequence: R,S01,R 
+            # DistSpec:
+            #     shard_sequence: R,S01,R
             #     device_mesh_shape: (4, 4)
             sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
-            # DistSpec: 
-            #     shard_sequence: S01,R,R 
+            # DistSpec:
+            #     shard_sequence: S01,R,R
             #     device_mesh_shape: (4, 4)
             sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
             transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target)
             print(f'transform_path: {transform_path}')
             print(f'comm_action_sequence: {comm_action_sequence}')
             print(f'total_cost: {total_cost}')
-        
+
         output:
-            transform_path: [DistSpec: 
-                    shard_sequence: R,S01,R 
-                    device_mesh_shape: (4, 4), DistSpec: 
-                    shard_sequence: R,S0,R 
-                    device_mesh_shape: (4, 4), DistSpec: 
-                    shard_sequence: S0,R,R 
-                    device_mesh_shape: (4, 4), DistSpec: 
-                    shard_sequence: S01,R,R 
+            transform_path: [DistSpec:
+                    shard_sequence: R,S01,R
+                    device_mesh_shape: (4, 4), DistSpec:
+                    shard_sequence: R,S0,R
+                    device_mesh_shape: (4, 4), DistSpec:
+                    shard_sequence: S0,R,R
+                    device_mesh_shape: (4, 4), DistSpec:
+                    shard_sequence: S01,R,R
                     device_mesh_shape: (4, 4)]
-            comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 
+            comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1),
                     CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0),
                     CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)]
             total_cost: 12294.402000000002
@@ -403,6 +417,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             return (transform_path, comm_action_sequence, total_cost_dict)
 
         temp_sharding_spec = source_spec
+
         transform_path.append(temp_sharding_spec)
         # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
         while total_steps <= MAX_TRANSFORM_STEPS:
@@ -437,13 +452,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
 
     def apply(self, tensor_with_sharding_spec, target_spec):
         '''
-        Apply target_spec to tensor with source sharding spec, the transform path is generated by the 
+        Apply target_spec to tensor with source sharding spec, the transform path is generated by the
         shape_consistency method.
-        
+
         Argument:
             tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec.
             target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec.
-        
+
         Example:
             physical_mesh_id = torch.arange(0, 4)
             mesh_shape = (2, 2)
@@ -459,7 +474,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # shard_sequence: S0,R
             # device_mesh_shape: (2, 2)
             sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
-        
+
             # DistSpec:
             # shard_sequence: R,S0
             # device_mesh_shape: (2, 2)
@@ -481,13 +496,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             tensor_to_comm.sharding_spec = sharding_spec_source
             shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
             print(tensor_to_comm)
-        
+
         Output in rank0 and rank2:
             tensor([[0.],
                     [0.],
                     [2.],
                     [2.]])
-        
+
         Output in rank1 and rank3:
             tensor([[1.],
                     [1.],
@@ -505,4 +520,4 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         for comm_spec in comm_action_sequence:
             comm_spec.covert_spec_to_action(tensor)
         tensor.sharding_spec = target_spec
-        return tensor
\ No newline at end of file
+        return tensor
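
The ConvModuleHandler hunk above re-initialises the weight's sharding spec in place against op_data.data.shape rather than the spec's previously recorded entire_shape. Below is a minimal sketch of that in-place re-initialisation, assuming only the ShardingSpec and DeviceMesh constructors that appear in this patch; the mesh, weight shape and partition dicts are illustrative and not taken from the handler.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# a 2x2 logical mesh over 4 devices, as in the docstring examples above
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# hypothetical conv weight [out_channels, in_channels, kH, kW]
weight_shape = torch.Size((16, 8, 3, 3))

# shard dim 0 on mesh axis 0 and dim 1 on mesh axis 1
spec = ShardingSpec(device_mesh, weight_shape, dim_partition_dict={0: [0], 1: [1]})

# swap the partitions of the first two dims and rebuild the spec against the
# tensor's own shape, mirroring sharding_spec.__init__(...) in the hunk
swapped_partition = {0: [1], 1: [0]}
spec.__init__(spec.device_mesh, weight_shape, swapped_partition)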
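The NodeHandler hunk gathers one sharding spec per predecessor strategy so that resharding into the current node's spec can be priced. A hedged sketch of turning those specs into per-strategy costs with the shape_consistency API documented in this patch follows; resharding_costs_for, prev_specs and current_spec are placeholder names, and the exact structure of the returned total cost (a scalar in the docstring example, a total_cost_dict in the code) is an assumption to verify against the implementation.

from colossalai.tensor.shape_consistency import ShapeConsistencyManager


def resharding_costs_for(prev_specs, current_spec):
    # ShapeConsistencyManager is a singleton (SingletonMeta), so constructing it here is cheap
    manager = ShapeConsistencyManager()
    costs = []
    for prev_spec in prev_specs:
        # shape_consistency returns (transform_path, comm_action_sequence, total_cost),
        # per the docstring example in this patch
        _, _, total_cost = manager.shape_consistency(prev_spec, current_spec)
        costs.append(total_cost)
    return costs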
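The substantive change in shape_consistency.py is that every candidate ShardingSpec built while enumerating one-step transforms (all-gather, all-to-all, shard) is now wrapped in try/except ShardingSpecException, so partitions the spec cannot express are skipped instead of aborting the enumeration. A minimal, self-contained sketch of that filtering pattern is below; the candidate partitions are made up, and it assumes ShardingSpec raises ShardingSpecException for partitions it cannot realise (for example, a dimension that does not divide evenly across the chosen mesh axes).

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))
entire_shape = torch.Size((4, 6, 3))

candidates = [
    {0: [0]},       # dim 0 (size 4) over mesh axis 0 (size 2): valid
    {2: [0, 1]},    # dim 2 (size 3) over both axes (4 shards): expected to be rejected
]

valid_specs = []
for dim_partition_dict in candidates:
    try:
        spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict=dim_partition_dict)
        valid_specs.append(spec)
    except ShardingSpecException:
        # mirror the patch: silently drop partitions the spec cannot express
        pass

print(valid_specs)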