[autoparallel] use metainfo in handler (#2149)

YuliangLiu0306 2022-12-20 10:31:22 +08:00 committed by GitHub
parent 9b39170a5c
commit 1cce6e36ca
11 changed files with 105 additions and 31 deletions


@@ -28,7 +28,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     inplace = kwargs.get("inplace", False)
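
Across the meta-info hunks in this commit the pattern is the same: the type-based next(filter(...)) lookup over the OperationData list is replaced by positional indexing into args. A minimal runnable sketch of the assumed layout, using a hypothetical stand-in for OperationData (the real ColossalAI class is richer and uses an OperationDataType enum):

    # Hedged sketch: FakeOperationData is a stand-in, not the real
    # colossalai OperationData; the positional layout (input, weight,
    # output, bias) is inferred from this diff alone.
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class FakeOperationData:
        type: str    # "ARG" | "PARAM" | "OUTPUT"
        data: Any


    args = [
        FakeOperationData("ARG", "input tensor"),      # args[0]: input
        FakeOperationData("PARAM", "weight tensor"),   # args[1]: weight
        FakeOperationData("OUTPUT", "output tensor"),  # args[2]: output
    ]

    # old lookup: scan for the first ARG entry
    old_style = next(filter(lambda x: x.type == "ARG", args)).data
    # new lookup: rely on the fixed position
    new_style = args[0].data
    assert old_style == new_style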


@@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     """
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]
 
     # check if conv has bias
     if len(weight_tensors) > 1:


@@ -66,9 +66,13 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     """
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
-    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    input_tensor = args[0].data
+    output_tensor = args[2].data
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]
 
     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
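
For the conv and linear meta functions, the bias case is now detected by arity instead of by collecting every PARAM entry; a hedged sketch of that convention (positions inferred from the diff, not verified against the ColossalAI source):

    # Assumed layout: (input, weight, output) without bias, or
    # (input, weight, output, bias) with bias, so args[3] holds the bias.
    def collect_weight_tensors(args):
        if len(args) == 4:
            return [args[1].data, args[3].data]  # weight and bias
        return [args[1].data]                    # weight only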


@@ -45,7 +45,7 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
     bias_tensor = next(filter(lambda x: x.name == "bias", args)).data


@@ -30,7 +30,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
 
     # construct forward args for flop mapping


@@ -2,8 +2,10 @@ from typing import Dict, List
 import torch
 
-from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import BatchNormStrategyGenerator, StrategyGenerator
@@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
 @operator_registry.register(torch.nn.BatchNorm1d)
 @operator_registry.register(torch.nn.BatchNorm2d)
 @operator_registry.register(torch.nn.BatchNorm3d)
-class BatchNormModuleHandler(ModuleHandler):
+class BatchNormModuleHandler(MetaInfoModuleHandler):
     """
     A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
     """


@@ -3,18 +3,12 @@ from typing import Dict, List, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 
 from ..constants import BCAST_FUNC_OP
 from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
@@ -22,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.


@@ -3,9 +3,9 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
 
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import transpose_partition_dim
-from .node_handler import ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
 @operator_registry.register(torch.nn.Conv1d)
 @operator_registry.register(torch.nn.Conv2d)
 @operator_registry.register(torch.nn.Conv3d)
-class ConvModuleHandler(ModuleHandler):
+class ConvModuleHandler(MetaInfoModuleHandler):
     """
     A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
     """
@@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
 @operator_registry.register(F.conv1d)
 @operator_registry.register(F.conv2d)
 @operator_registry.register(F.conv3d)
-class ConvFunctionHandler(NodeHandler):
+class ConvFunctionHandler(MetaInfoNodeHandler):
     """
     A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
     """


@@ -3,12 +3,16 @@ from typing import Dict, List, Union
 import torch
 import torch.nn.functional as F
 
-from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import (
+    check_sharding_spec_validity,
+    transpose_partition_dim,
+    update_partition_dim,
+)
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
 
-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
@@ -139,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
 @operator_registry.register(torch.nn.Linear)
-class LinearModuleHandler(ModuleHandler):
+class LinearModuleHandler(MetaInfoModuleHandler):
     """
     A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
     """
@@ -199,7 +203,7 @@ class LinearModuleHandler(ModuleHandler):
 @operator_registry.register(F.linear)
-class LinearFunctionHandler(NodeHandler):
+class LinearFunctionHandler(MetaInfoNodeHandler):
     """
     A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
     """


@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -133,6 +134,26 @@ class NodeHandler(ABC):
         strategy.resharding_costs = resharding_costs
         return strategy
 
+    def get_target_function(self) -> callable:
+        """
+        This function is used to get the target function for the node handler.
+        The target function is used to analyze the costs of strategies.
+        """
+        if self.node.op in ('placeholder', 'get_attr', 'output'):
+            return None
+
+        if self.node.op == 'call_module':
+            submod = self.node.graph.owning_module.get_submodule(self.node.target)
+            target = type(submod)
+        elif self.node.op == 'call_function':
+            target = self.node.target
+        elif self.node.op == 'call_method':
+            target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+        else:
+            raise ValueError(f'Unsupported node type: {self.node.op}')
+
+        return target
+
     def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
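
A hedged illustration of the target resolution added above, run on a toy torch.fx graph (standard torch.fx API only; the _meta_data attribute consulted for call_method nodes is specific to the autoparallel pass and is not exercised here):

    import torch
    from torch.fx import symbolic_trace


    class Toy(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))


    gm = symbolic_trace(Toy())
    for node in gm.graph.nodes:
        if node.op == 'call_module':
            # resolved to the module class, e.g. torch.nn.Linear
            print(node.op, type(gm.get_submodule(node.target)))
        elif node.op == 'call_function':
            # node.target is already the function object, e.g. torch.relu
            print(node.op, node.target)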
@@ -204,6 +225,29 @@ class NodeHandler(ABC):
         pass
 
 
+class MetaInfoNodeHandler(NodeHandler):
+    """
+    This is a base class to handle the nodes patched in the meta profiler.
+
+    Note: this class will be integrated into the NodeHandler class in the future, after
+    all the functions are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
+
+
 class ModuleHandler(NodeHandler):
 
     def __init__(self, *args, **kwargs) -> None:
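
A runnable stand-in for the rewriting loop above (every class here is a fake; only the control flow mirrors the diff, while the real MetaInfo derives its costs from the patched meta functions):

    class FakeStrategy:

        def __init__(self, name):
            self.name = name
            self.compute_cost = None    # normally set by a strategy generator
            self.memory_cost = None


    class FakeMetaInfo:
        """Stand-in for MetaInfo(strategy, target)."""

        def __init__(self, strategy, target):
            self.compute_cost = 2.0     # pretend profiled compute cost
            self.memory_cost = 128.0    # pretend profiled memory cost


    def rewrite_costs(strategies_vector, target=None):
        # mirrors MetaInfoNodeHandler.register_strategy's post-processing
        for strategy in strategies_vector:
            metainfo = FakeMetaInfo(strategy, target)
            strategy.compute_cost = metainfo.compute_cost
            strategy.memory_cost = metainfo.memory_cost
        return strategies_vector


    strategies = rewrite_costs([FakeStrategy("S0"), FakeStrategy("S1")])
    print([(s.name, s.compute_cost, s.memory_cost) for s in strategies])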
@@ -221,3 +265,26 @@ class ModuleHandler(NodeHandler):
         self.module = module
         self.named_parameters = named_parameters
         self.named_buffers = named_buffers
+
+
+class MetaInfoModuleHandler(ModuleHandler):
+    """
+    This is a base class to handle the module patched in the meta profiler.
+
+    Note: this class will be integrated into the ModuleHandler class in the future, after
+    all the modules are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
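
MetaInfoModuleHandler.register_strategy is line-for-line identical to the MetaInfoNodeHandler version, and both docstrings note the classes are temporary until every op is patched; a hedged refactoring sketch, not part of this commit, of how the duplication could be shared:

    # Hypothetical mixin: both MetaInfo* handlers could reuse one
    # cost-rewriting override through cooperative super() calls.
    class _MetaInfoCostMixin:

        def register_strategy(self, compute_resharding_cost: bool = True):
            super().register_strategy(compute_resharding_cost=compute_resharding_cost)
            target = self.get_target_function()
            for strategy in self.strategies_vector:
                metainfo = MetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
            return self.strategies_vector


    class MetaInfoNodeHandler(_MetaInfoCostMixin, NodeHandler):
        pass


    class MetaInfoModuleHandler(_MetaInfoCostMixin, ModuleHandler):
        pass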


@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 
 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
@@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
 @operator_registry.register(torch.nn.AvgPool1d)
 @operator_registry.register(torch.nn.AvgPool2d)
 @operator_registry.register(torch.nn.AvgPool3d)
-class NormPoolingHandler(ModuleHandler):
+class NormPoolingHandler(MetaInfoModuleHandler):
     """
     A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
     """