diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
index 5ab6289b7..ab3acb056 100644
--- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py
+++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py
@@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     return meta_info
 
 
-def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
+def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
     """
     This method is used to construct `MetaInto` for shape consistency node
     """
@@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s
     # extract node index and user node index
     args = node.args
     node_index, user_node_index = args[3], args[4]
-    origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
-        node_index][user_node_index]
+    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
+        user_node_index]
 
     return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
 
@@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
     return meta_info
 
 
-def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
-                       comm_actions_dict: Dict):
+def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
+                       comm_actions_dict: Dict) -> GraphModule:
     """
     The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
     """
     for node in gm.graph.nodes:
         if node.target == runtime_apply:
-            setattr(node, 'best_metainfo',
-                    _runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
+            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
         elif node.target == runtime_comm_spec_apply:
            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
         else:
             pass
+    return gm
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index e8ae363e9..f510f7477 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
 
 
 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two operands and
     broadcasting occurs such as torch.add.
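Usage sketch (not part of the patch): since comm_metainfo_pass now returns the GraphModule it annotates, the pass can be composed with other graph passes. A minimal example, assuming sharding_spec_dict, origin_spec_dict and comm_actions_dict have already been produced by the surrounding auto-parallel machinery (the variable names here are illustrative):

    from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass

    # Attach 'best_metainfo' to every runtime_apply / runtime_comm_spec_apply
    # node in the graph; the pass now hands back gm, so calls can be chained
    # or fed straight into the next pass.
    gm = comm_metainfo_pass(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)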
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
index b46348716..7763b1884 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ReshapeGenerator, StrategyGenerator
 
@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.unsqueeze)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
-class ReshapeHandler(NodeHandler):
+class ReshapeHandler(MetaInfoNodeHandler):
     """
     A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
     """
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
index bda160906..0362de780 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import StrategyGenerator, UnaryElementwiseGenerator
 
@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.nn.modules.dropout.Dropout)
 @operator_registry.register(torch.Tensor.contiguous)
 @operator_registry.register(torch.nn.functional.dropout)
-class UnaryElementwiseHandler(NodeHandler):
+class UnaryElementwiseHandler(MetaInfoNodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
     """
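Pattern sketch (not part of the patch): the handlers above opt into metainfo collection purely by switching their base class from NodeHandler to MetaInfoNodeHandler; strategy registration is unchanged. A hypothetical new handler would follow the same shape (MyOpHandler and the torch.sigmoid registration are illustrative, not part of this change):

    import torch

    from .node_handler import MetaInfoNodeHandler
    from .registry import operator_registry

    @operator_registry.register(torch.sigmoid)
    class MyOpHandler(MetaInfoNodeHandler):
        """
        Sketch: subclassing MetaInfoNodeHandler keeps the usual strategy
        generation of NodeHandler while also recording metainfo for the node.
        """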