From a4ce180e8500b207835e82f5820e9c698261e37f Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 20 Oct 2022 18:48:18 +0800
Subject: [PATCH] [autoparallel] add sequential order to communication actions
 (#1735)

---
 .../strategy/conv_strategy_generator.py      | 128 +++++++++++-------
 .../strategy/reshape_generator.py            |  30 +++-
 .../strategy/strategy_generator.py           |  56 ++++++--
 .../tensor_shard/sharding_strategy.py        |  41 +++++-
 .../adding_shape_consistency_pass_v2.py      |  68 +++++++++-
 colossalai/tensor/comm_spec.py               |  14 +-
 .../test_shape_consistency_pass.py           |  46 +++++--
 7 files changed, 293 insertions(+), 90 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index 427eea671..83476e4fe 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -4,9 +4,18 @@ import warnings
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+
 from colossalai.auto_parallel.tensor_shard.utils import \
     ignore_sharding_exception
+
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -122,26 +131,28 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-
-        communication_action_mapping = {"input": input_comm_spec}
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE)
+        communication_action_mapping = {"input": input_comm_action}
 
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -167,18 +178,20 @@ class ConvStrategyGenerator(StrategyGenerator):
         communication_action_mapping = {}
 
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -206,26 +219,30 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -256,16 +273,20 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
-        input_comm_spec = self.get_communication_spec(
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER,
+            arg_index=0)
+        input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
+        communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -291,12 +312,14 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -324,12 +347,13 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.BEFORE)
 
-        communication_action_mapping = {"input": input_comm_spec}
+        communication_action_mapping = {"input": input_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -375,18 +399,20 @@ class ConvStrategyGenerator(StrategyGenerator):
         communication_action_mapping = {}
 
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -411,12 +437,14 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -443,12 +471,14 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.BEFORE,
+            arg_index=0)
 
-        communication_action_mapping = {"input": input_comm_spec}
+        communication_action_mapping = {"input": input_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index 8fa5a8137..cbe0f0746 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -1,8 +1,15 @@
 import copy
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from colossalai.tensor.sharding_spec import ShardingSpec
 
 from .strategy_generator import FollowingStrategyGenerator
 
@@ -81,12 +88,23 @@ class ReshapeGenerator(FollowingStrategyGenerator):
         # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
         if len(total_mesh_dim_list) == 1:
             total_mesh_dim_list = total_mesh_dim_list[0]
+            input_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping["input"],
+                communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+                logical_process_axis=total_mesh_dim_list,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
 
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-            logical_process_axis=total_mesh_dim_list)
-        communication_action_mapping["input"] = input_comm_spec
+        else:
+            source_spec = sharding_spec_mapping["input"]
+            target_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                       entire_shape=source_spec.entire_shape,
+                                       dim_partition_dict={})
+            comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+            input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
+
+        communication_action_mapping["input"] = input_comm_action
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
                                               communication_action_mapping=communication_action_mapping)
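For readers following the reshape change above: after this patch a CommAction's comm_spec field can carry either a ready-made CommSpec (a single collective) or a plain dict with 'src_spec'/'tgt_spec' ShardingSpecs (a full resharding). The sketch below is illustrative only and not part of the patch; the helper name handle_comm_action is an assumption. It mirrors the branching that update_communication_cost performs in the strategy generator diff that follows.

# Illustrative sketch only: shows the two payload shapes a CommAction.comm_spec may hold.
def handle_comm_action(comm_action, shape_consistency_manager):
    comm_spec = comm_action.comm_spec
    if isinstance(comm_spec, dict):
        # full resharding described by a source/target ShardingSpec pair
        src_spec, tgt_spec = comm_spec['src_spec'], comm_spec['tgt_spec']
        _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
        return comm_action_sequence
    # a single collective described by a CommSpec
    return [comm_spec]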
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index 6196e8336..6bbb15e57 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -4,17 +4,27 @@ from functools import reduce
 from typing import Any, Dict, List, Union
 
 import torch
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
-                                                                     TrainCycleItem)
+
+from torch.fx import Node
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from torch.fx import Node
 
 
 class StrategyGenerator(ABC):
     """
-    StrategyGenerator is used to generate the same group of sharding strategies. 
+    StrategyGenerator is used to generate the same group of sharding strategies.
     TODO: remove the original strategy_generator.py after refactoring
     """
@@ -97,6 +107,21 @@ class StrategyGenerator(ABC):
                                          sharding_spec=sharding_spec,
                                          logical_process_axis=logical_process_axis)
 
+    def get_communication_action(self,
+                                 sharding_spec: ShardingSpec,
+                                 communication_pattern: CollectiveCommPattern,
+                                 logical_process_axis: Union[int, List[int]],
+                                 comm_type: CommType,
+                                 arg_index: int = -1) -> CommAction:
+        """
+        A factory method to produce a CommAction object.
+        """
+        return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
+                                                                communication_pattern=communication_pattern,
+                                                                logical_process_axis=logical_process_axis),
+                          comm_type=comm_type,
+                          arg_index=arg_index)
+
     def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         """
         Compute the communication cost involved in the forward and backward iteration.
@@ -117,8 +142,21 @@ class StrategyGenerator(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions.items():
-                _compute_and_add(operand, comm_spec)
+            for operand, comm_action in strategy.communication_actions.items():
+                if isinstance(comm_action, CommAction):
+                    comm_spec = comm_action.comm_spec
+                else:
+                    # this branch will be removed once all the handlers are updated.
+                    comm_spec = comm_action
+                if isinstance(comm_spec, dict):
+                    src_spec = comm_spec['src_spec']
+                    tgt_spec = comm_spec['tgt_spec']
+                    shape_consistency_manager = ShapeConsistencyManager()
+                    _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
+                    for comm_spec_ in comm_action_sequence:
+                        _compute_and_add(operand, comm_spec_)
+                else:
+                    _compute_and_add(operand, comm_spec)
 
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost
@@ -141,7 +179,7 @@ class StrategyGenerator(ABC):
     def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
         """
         Compute the size of a tensor in bytes.
-        
+
         Args:
             strategy (ShardingStrategy): the ShardingStrategy generated.
             key (str): the name of the operation data defined by the generator.
@@ -182,7 +220,7 @@ class StrategyGenerator(ABC):
     @abstractmethod
     def validate(self) -> bool:
         """
-        Validate if the operands are of desired shape. 
+        Validate if the operands are of the desired shape.
         If True, means this generator can be used for the current operation.
         """
         pass
@@ -190,7 +228,7 @@ class FollowingStrategyGenerator(StrategyGenerator):
     """
-    FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node. 
+    FollowingStrategyGenerator is used to generate the sharding strategies which depend on its predecessor node.
     TODO: remove the original strategy_generator.py after refactoring
     """
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index ed5731e9a..8dbb0014b 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -4,11 +4,12 @@ from enum import Enum
 from typing import Any, Dict, List, Tuple, Union
 
 import torch
-from colossalai.tensor.shape_consistency import CommSpec
-from colossalai.tensor.sharding_spec import ShardingSpec
 from torch.fx.node import Node
 
-from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
+from colossalai.tensor.shape_consistency import CommSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@@ -84,6 +85,38 @@ class MemoryCost:
     buffer: int = 0
 
 
+class CommType(Enum):
+    """
+    CommType describes the sequential order of a communication action and a computation action.
+
+    Meaning:
+        BEFORE: the communication action happens just before the computation operation.
+        AFTER: the communication action happens after the computation operation.
+        HOOK: the communication action is used to all-reduce gradients.
+        IMPLICIT: the communication action happens during kernel execution, such as SyncBatchNorm.
+    """
+    BEFORE = 0
+    AFTER = 1
+    HOOK = 2
+    IMPLICIT = 3
+
+
+@dataclass
+class CommAction:
+    """
+    CommAction is used to record a communication action.
+
+    Args:
+        comm_spec: expresses the communication pattern and the process groups used to execute the communication action.
+        comm_type: describes the sequential order of a communication action and a computation action.
+        arg_index: records the location of the tensor which joins the communication. We cannot use the name of the node
+            or op_data at runtime, because the args of a node may be changed by graph transform passes.
+    """
+    comm_spec: CommSpec = None
+    comm_type: CommType = None
+    arg_index: int = -1
+
+
 @dataclass
 class ShardingStrategy:
     """
@@ -102,7 +135,7 @@ class ShardingStrategy:
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    communication_actions: Dict[OperationData, CommSpec] = None
+    communication_actions: Dict[OperationData, CommAction] = None
     resharding_costs: Dict[Node, List[TrainCycleItem]] = None
 
     @property
diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
index a9e7109f2..d40ab0f0c 100644
--- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
+++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py
@@ -8,8 +8,10 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.comm_spec import CommSpec, _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
@@ -19,9 +21,9 @@ shape_consistency_manager = ShapeConsistencyManager()
 class ConsistencyApply(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, node, origin_dict, input_dict, node_index, user_node_index):
-        ctx.origin_sharding_spec = origin_dict[node_index]
-        ctx.target_sharding_spec = input_dict[node_index][user_node_index]
+    def forward(ctx, node, origin_sharding_spec, target_sharding_spec):
+        ctx.origin_sharding_spec = origin_sharding_spec
+        ctx.target_sharding_spec = target_sharding_spec
         return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
                                                                         ctx.target_sharding_spec)
@@ -32,7 +34,9 @@
 def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
-    return ConsistencyApply.apply(node, origin_dict, input_dict, node_index, user_node_index)
+    origin_sharding_spec = origin_dict[node_index]
+    target_sharding_spec = input_dict[node_index][user_node_index]
+    return ConsistencyApply.apply(node, origin_sharding_spec, target_sharding_spec)
 
 
 def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
@@ -41,6 +45,18 @@ def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
 
 
+def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
+
+    comm_action = comm_actions_dict[node_index][op_data]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        rst = comm_action.comm_spec.covert_spec_to_action(tensor)
+    else:
+        origin_sharding_spec = comm_action.comm_spec['src_spec']
+        tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+        rst = ConsistencyApply.apply(tensor, origin_sharding_spec, tgt_sharding_spec)
+    return rst
+
+
 def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
     mod_graph = gm.graph
     nodes = tuple(mod_graph.nodes)
@@ -63,6 +79,16 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                 setattr(param, 'sharding_spec', origin_sharding_spec)
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 shape_consistency_manager.apply(param, target_sharding_spec)
+                comm_actions = node.best_strategy.communication_actions
+
+                for operation_data, comm_action in comm_actions.items():
+                    comm_spec_to_use = comm_action.comm_spec
+                    if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
+
+                        def hook_fn(grad):
+                            _all_reduce(grad, comm_spec_to_use)
+
+                        param.register_hook(hook_fn)
 
             for name, buffer in target_module.named_buffers():
                 origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
@@ -79,15 +105,24 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
 
+    # the dict to record comm actions of nodes
+    comm_actions_dict = {}
+    for index, node in enumerate(nodes):
+        comm_action_dict = {}
+        for op_data, comm_action in node.best_strategy.communication_actions.items():
+            comm_action_dict[op_data.name] = comm_action
+        comm_actions_dict[index] = comm_action_dict
+
     # add above dicts into graph
     for node in nodes:
         if node.op != 'placeholder':
             with mod_graph.inserting_before(node):
                 input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
                 origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
+                comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
             break
 
-    return sharding_spec_convert_dict, origin_node_sharding_spec_dict
+    return sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
 
 
 def shape_consistency_pass(gm: torch.fx.GraphModule):
@@ -106,6 +141,9 @@
         if node.target == 'origin_node_sharding_spec_dict':
             origin_dict_node = node
             continue
+        if node.target == 'comm_actions_dict':
+            comm_actions_dict_node = node
+            continue
         if not hasattr(node, 'best_strategy'):
             continue
         node_to_index_dict[node] = index
@@ -138,4 +176,24 @@
                         new_args[origin_index_args] = shape_consistency_node
                         user_node.args = new_args
 
+        comm_actions = node.best_strategy.communication_actions
+        for op_data, comm_action in comm_actions.items():
+            comm_object = node.args[comm_action.arg_index]
+            if op_data.type == OperationDataType.ARG:
+                if comm_action.comm_type == CommType.BEFORE:
+                    with mod_graph.inserting_before(node):
+                        comm_spec_apply_node = mod_graph.create_node('call_function',
+                                                                     runtime_comm_spec_apply,
+                                                                     args=(comm_object, comm_actions_dict_node,
+                                                                           node_to_index_dict[node], op_data.name))
+                elif comm_action.comm_type == CommType.AFTER:
+                    with mod_graph.inserting_after(node):
+                        comm_spec_apply_node = mod_graph.create_node('call_function',
+                                                                     runtime_comm_spec_apply,
+                                                                     args=(comm_object, comm_actions_dict_node,
+                                                                           node_to_index_dict[node], op_data.name))
+            # TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
+            new_args = list(node.args)
+            new_args[comm_action.arg_index] = comm_spec_apply_node
+            node.args = new_args
     return gm
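The HOOK branch above registers a backward hook that all-reduces a parameter's gradient. Below is a simplified, hedged sketch of the same idea as a standalone function; the real pass calls _all_reduce(grad, comm_spec) so the reduction runs over the logical process group recorded in the CommSpec, whereas this sketch uses the default torch.distributed process group for brevity.

# Illustrative sketch only, assuming torch.distributed has already been initialized.
import torch
import torch.distributed as dist

def register_grad_allreduce_hook(param: torch.nn.Parameter) -> None:
    def hook_fn(grad: torch.Tensor) -> torch.Tensor:
        # sum the gradient across all ranks in the default process group
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

    param.register_hook(hook_fn)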
diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py
index 8f51f21cf..646ded54e 100644
--- a/colossalai/tensor/comm_spec.py
+++ b/colossalai/tensor/comm_spec.py
@@ -1,8 +1,9 @@
-import torch
-from enum import Enum
-import torch.distributed as dist
-from functools import reduce
 import operator
+from enum import Enum
+from functools import reduce
+
+import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp
 
 __all__ = [
@@ -238,7 +239,7 @@
         1. Compute the communication cost which will be used in auto parallel solver.
         2. Convert the communication spec to real action which will be used in runtime.
         It contains comm_pattern to determine the
-        communication method, sharding_spec to determine the communication size, gather_dim and shard_dim 
+        communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
         to determine the buffer shape, and logical_process_axis
 
         Argument:
@@ -296,7 +297,7 @@
         '''
         For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh
         with alpha-beta model is used to compute the communication cost.
-        For shard operation, it is an on-chip operation, so the communication cost is zero. 
+        For shard operation, it is an on-chip operation, so the communication cost is zero.
         '''
         comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
         cost_dict = {}
@@ -347,6 +348,7 @@
             tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
         else:
             tensor.data = tensor
+        return tensor
 
 
 pattern_to_func_dict = {
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
index ae15106b0..7dd0ae842 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
@@ -1,3 +1,4 @@
+import copy
 from functools import partial
 
 import pytest
@@ -6,15 +7,22 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.fx import GraphModule
 
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                          StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
-                                                                                solution_annotatation_pass)
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
@@ -27,6 +35,7 @@ class ConvModel(nn.Module):
 
     def forward(self, x):
         x = self.conv(x)
+        x = torch.flatten(x)
         return x
 
 
@@ -38,12 +47,13 @@ def check_apply(rank, world_size, port):
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
-    entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
     tracer = ColoTracer()
     model = ConvModel(4, 4).cuda()
-    origin_output = model(input)
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
+
     input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
 
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]
     #     %conv : [#users=1] = call_module[target=conv](args = (%x,), kwargs = {})
@@ -62,16 +72,30 @@ def check_apply(rank, world_size, port):
     solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
-    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    origin_output = test_model(test_input)
     assert output.equal(origin_output)
+    origin_loss = origin_output.sum()
+    loss = output.sum()
+
+    origin_loss.backward()
+    loss.backward()
+
+    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
+    grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
+
+    if rank in (0, 1):
+        assert_close(gm.conv.weight.grad.data, grad_0.data)
+    elif rank in (2, 3):
+        assert_close(gm.conv.weight.grad.data, grad_1.data)
 
 
+# skip this test due to pulp not installed in CI environment
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()