[autockpt] make it work. (#2257)

pull/2261/head
Super Daniel 2023-01-02 23:37:45 +08:00 committed by GitHub
parent ac3739930d
commit 3ccf58aa76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 12 deletions

View File

@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
return meta_info return meta_info
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo: def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
""" """
This method is used to construct `MetaInto` for shape consistency node This method is used to construct `MetaInto` for shape consistency node
""" """
@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s
# extract node index and user node index # extract node index and user node index
args = node.args args = node.args
node_index, user_node_index = args[3], args[4] node_index, user_node_index = args[3], args[4]
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[ origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
node_index][user_node_index] user_node_index]
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
return meta_info return meta_info
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict, def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
comm_actions_dict: Dict): comm_actions_dict: Dict) -> GraphModule:
""" """
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
""" """
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.target == runtime_apply: if node.target == runtime_apply:
setattr(node, 'best_metainfo', setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
elif node.target == runtime_comm_spec_apply: elif node.target == runtime_comm_spec_apply:
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
else: else:
pass pass
return gm

View File

@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
@operator_registry.register(BCAST_FUNC_OP) @operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler(NodeHandler): class BinaryElementwiseHandler(MetaInfoNodeHandler):
""" """
An BinaryBcastOpHandler is a node handler which deals with operations which have two An BinaryBcastOpHandler is a node handler which deals with operations which have two
operands and broadcasting occurs such as torch.add. operands and broadcasting occurs such as torch.add.

View File

@ -3,7 +3,7 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator from .strategy import ReshapeGenerator, StrategyGenerator
@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
@operator_registry.register(torch.flatten) @operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze) @operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d) @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler): class ReshapeHandler(MetaInfoNodeHandler):
""" """
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
""" """

View File

@ -3,7 +3,7 @@ from typing import Dict, List
import torch import torch
from ..sharding_strategy import OperationData, OperationDataType from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator from .strategy import StrategyGenerator, UnaryElementwiseGenerator
@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.nn.modules.dropout.Dropout) @operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous) @operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout) @operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(NodeHandler): class UnaryElementwiseHandler(MetaInfoNodeHandler):
""" """
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
""" """