mirror of https://github.com/hpcaitech/ColossalAI
[autockpt] make it work. (#2257)
parent
ac3739930d
commit
3ccf58aa76
|
@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
|
||||||
return meta_info
|
return meta_info
|
||||||
|
|
||||||
|
|
||||||
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
|
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
|
||||||
"""
|
"""
|
||||||
This method is used to construct `MetaInto` for shape consistency node
|
This method is used to construct `MetaInto` for shape consistency node
|
||||||
"""
|
"""
|
||||||
|
@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s
|
||||||
# extract node index and user node index
|
# extract node index and user node index
|
||||||
args = node.args
|
args = node.args
|
||||||
node_index, user_node_index = args[3], args[4]
|
node_index, user_node_index = args[3], args[4]
|
||||||
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
|
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
|
||||||
node_index][user_node_index]
|
user_node_index]
|
||||||
|
|
||||||
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
||||||
|
|
||||||
|
@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
|
||||||
return meta_info
|
return meta_info
|
||||||
|
|
||||||
|
|
||||||
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
|
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
|
||||||
comm_actions_dict: Dict):
|
comm_actions_dict: Dict) -> GraphModule:
|
||||||
"""
|
"""
|
||||||
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
|
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
|
||||||
"""
|
"""
|
||||||
for node in gm.graph.nodes:
|
for node in gm.graph.nodes:
|
||||||
if node.target == runtime_apply:
|
if node.target == runtime_apply:
|
||||||
setattr(node, 'best_metainfo',
|
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
|
||||||
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
|
|
||||||
elif node.target == runtime_comm_spec_apply:
|
elif node.target == runtime_comm_spec_apply:
|
||||||
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
|
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
return gm
|
||||||
|
|
|
@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
|
||||||
|
|
||||||
|
|
||||||
@operator_registry.register(BCAST_FUNC_OP)
|
@operator_registry.register(BCAST_FUNC_OP)
|
||||||
class BinaryElementwiseHandler(NodeHandler):
|
class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||||
"""
|
"""
|
||||||
An BinaryBcastOpHandler is a node handler which deals with operations which have two
|
An BinaryBcastOpHandler is a node handler which deals with operations which have two
|
||||||
operands and broadcasting occurs such as torch.add.
|
operands and broadcasting occurs such as torch.add.
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, List
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..sharding_strategy import OperationData, OperationDataType
|
from ..sharding_strategy import OperationData, OperationDataType
|
||||||
from .node_handler import NodeHandler
|
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||||
from .registry import operator_registry
|
from .registry import operator_registry
|
||||||
from .strategy import ReshapeGenerator, StrategyGenerator
|
from .strategy import ReshapeGenerator, StrategyGenerator
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
|
||||||
@operator_registry.register(torch.flatten)
|
@operator_registry.register(torch.flatten)
|
||||||
@operator_registry.register(torch.Tensor.unsqueeze)
|
@operator_registry.register(torch.Tensor.unsqueeze)
|
||||||
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
||||||
class ReshapeHandler(NodeHandler):
|
class ReshapeHandler(MetaInfoNodeHandler):
|
||||||
"""
|
"""
|
||||||
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, List
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..sharding_strategy import OperationData, OperationDataType
|
from ..sharding_strategy import OperationData, OperationDataType
|
||||||
from .node_handler import NodeHandler
|
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||||
from .registry import operator_registry
|
from .registry import operator_registry
|
||||||
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
|
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
|
||||||
@operator_registry.register(torch.nn.modules.dropout.Dropout)
|
@operator_registry.register(torch.nn.modules.dropout.Dropout)
|
||||||
@operator_registry.register(torch.Tensor.contiguous)
|
@operator_registry.register(torch.Tensor.contiguous)
|
||||||
@operator_registry.register(torch.nn.functional.dropout)
|
@operator_registry.register(torch.nn.functional.dropout)
|
||||||
class UnaryElementwiseHandler(NodeHandler):
|
class UnaryElementwiseHandler(MetaInfoNodeHandler):
|
||||||
"""
|
"""
|
||||||
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue