mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] use metainfo in handler (#2149)
parent
9b39170a5c
commit
1cce6e36ca
|
@ -28,7 +28,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
|
|||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
inplace = kwargs.get("inplace", False)
|
||||
|
||||
|
|
|
@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
|||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
|
||||
if len(args) == 4:
|
||||
weight_tensors = [args[1].data, args[3].data]
|
||||
else:
|
||||
weight_tensors = [args[1].data]
|
||||
|
||||
# check if conv has bias
|
||||
if len(weight_tensors) > 1:
|
||||
|
|
|
@ -66,9 +66,13 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
|
|||
"""
|
||||
|
||||
has_bias: bool = False
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
|
||||
|
||||
input_tensor = args[0].data
|
||||
output_tensor = args[2].data
|
||||
if len(args) == 4:
|
||||
weight_tensors = [args[1].data, args[3].data]
|
||||
else:
|
||||
weight_tensors = [args[1].data]
|
||||
|
||||
# process the dimension of input and output
|
||||
if len(input_tensor.shape) > 2:
|
||||
|
|
|
@ -45,7 +45,7 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
|
|||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
|
||||
bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
|
||||
|
|
|
@ -30,7 +30,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
|
|||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
|
||||
input_tensor = args[0].data
|
||||
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
|
||||
|
||||
# construct forward args for flop mapping
|
||||
|
|
|
@ -2,8 +2,10 @@ from typing import Dict, List
|
|||
|
||||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import ModuleHandler
|
||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from .node_handler import MetaInfoModuleHandler, ModuleHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import BatchNormStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
|
|||
@operator_registry.register(torch.nn.BatchNorm1d)
|
||||
@operator_registry.register(torch.nn.BatchNorm2d)
|
||||
@operator_registry.register(torch.nn.BatchNorm3d)
|
||||
class BatchNormModuleHandler(ModuleHandler):
|
||||
class BatchNormModuleHandler(MetaInfoModuleHandler):
|
||||
"""
|
||||
A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
|
||||
"""
|
||||
|
|
|
@ -3,18 +3,12 @@ from typing import Dict, List, Union
|
|||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
|
||||
|
||||
from ..constants import BCAST_FUNC_OP
|
||||
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
|
||||
from .node_handler import NodeHandler
|
||||
from .node_handler import MetaInfoNodeHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -22,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
|
|||
|
||||
|
||||
@operator_registry.register(BCAST_FUNC_OP)
|
||||
class BinaryElementwiseHandler(NodeHandler):
|
||||
class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
An BinaryBcastOpHandler is a node handler which deals with operations which have two
|
||||
operands and broadcasting occurs such as torch.add.
|
||||
|
|
|
@ -3,9 +3,9 @@ from typing import Dict, List
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
|
||||
from ..utils import transpose_partition_dim
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import ConvStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
|
|||
@operator_registry.register(torch.nn.Conv1d)
|
||||
@operator_registry.register(torch.nn.Conv2d)
|
||||
@operator_registry.register(torch.nn.Conv3d)
|
||||
class ConvModuleHandler(ModuleHandler):
|
||||
class ConvModuleHandler(MetaInfoModuleHandler):
|
||||
"""
|
||||
A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
|
||||
"""
|
||||
|
@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
|
|||
@operator_registry.register(F.conv1d)
|
||||
@operator_registry.register(F.conv2d)
|
||||
@operator_registry.register(F.conv3d)
|
||||
class ConvFunctionHandler(NodeHandler):
|
||||
class ConvFunctionHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
|
||||
"""
|
||||
|
|
|
@ -3,12 +3,16 @@ from typing import Dict, List, Union
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
check_sharding_spec_validity,
|
||||
transpose_partition_dim,
|
||||
update_partition_dim,
|
||||
)
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
|
||||
from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -139,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
|||
|
||||
|
||||
@operator_registry.register(torch.nn.Linear)
|
||||
class LinearModuleHandler(ModuleHandler):
|
||||
class LinearModuleHandler(MetaInfoModuleHandler):
|
||||
"""
|
||||
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
|
||||
"""
|
||||
|
@ -199,7 +203,7 @@ class LinearModuleHandler(ModuleHandler):
|
|||
|
||||
|
||||
@operator_registry.register(F.linear)
|
||||
class LinearFunctionHandler(NodeHandler):
|
||||
class LinearFunctionHandler(MetaInfoNodeHandler):
|
||||
"""
|
||||
A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
|
||||
"""
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union
|
|||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
|
@ -133,6 +134,26 @@ class NodeHandler(ABC):
|
|||
strategy.resharding_costs = resharding_costs
|
||||
return strategy
|
||||
|
||||
def get_target_function(self) -> callable:
|
||||
"""
|
||||
This function is used to get the target function for the node handler.
|
||||
The target function is used to analyze the costs of strategies.
|
||||
"""
|
||||
if self.node.op in ('placeholder', 'get_attr', 'output'):
|
||||
return None
|
||||
|
||||
if self.node.op == 'call_module':
|
||||
submod = self.node.graph.owning_module.get_submodule(self.node.target)
|
||||
target = type(submod)
|
||||
elif self.node.op == 'call_function':
|
||||
target = self.node.target
|
||||
elif self.node.op == 'call_method':
|
||||
target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
|
||||
else:
|
||||
raise ValueError(f'Unsupported node type: {self.node.op}')
|
||||
|
||||
return target
|
||||
|
||||
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
|
||||
"""
|
||||
Register different sharding strategies for the current node.
|
||||
|
@ -204,6 +225,29 @@ class NodeHandler(ABC):
|
|||
pass
|
||||
|
||||
|
||||
class MetaInfoNodeHandler(NodeHandler):
|
||||
"""
|
||||
This is a base class to handle the nodes patched in the meta profiler.
|
||||
|
||||
Note: this class will be integrated into the NodeHandler class in the future, after
|
||||
all the functions are patched.
|
||||
"""
|
||||
|
||||
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
|
||||
"""
|
||||
This method is inherited from NodeHandler. It will register the strategies first,
|
||||
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
|
||||
"""
|
||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||
target = self.get_target_function()
|
||||
for strategy in self.strategies_vector:
|
||||
metainfo = MetaInfo(strategy, target)
|
||||
strategy.compute_cost = metainfo.compute_cost
|
||||
strategy.memory_cost = metainfo.memory_cost
|
||||
|
||||
return self.strategies_vector
|
||||
|
||||
|
||||
class ModuleHandler(NodeHandler):
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
|
@ -221,3 +265,26 @@ class ModuleHandler(NodeHandler):
|
|||
self.module = module
|
||||
self.named_parameters = named_parameters
|
||||
self.named_buffers = named_buffers
|
||||
|
||||
|
||||
class MetaInfoModuleHandler(ModuleHandler):
|
||||
"""
|
||||
This is a base class to handle the module patched in the meta profiler.
|
||||
|
||||
Note: this class will be integrated into the ModuleHandler class in the future, after
|
||||
all the modules are patched.
|
||||
"""
|
||||
|
||||
def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
|
||||
"""
|
||||
This method is inherited from NodeHandler. It will register the strategies first,
|
||||
and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
|
||||
"""
|
||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||
target = self.get_target_function()
|
||||
for strategy in self.strategies_vector:
|
||||
metainfo = MetaInfo(strategy, target)
|
||||
strategy.compute_cost = metainfo.compute_cost
|
||||
strategy.memory_cost = metainfo.memory_cost
|
||||
|
||||
return self.strategies_vector
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, List
|
|||
import torch
|
||||
|
||||
from ..sharding_strategy import OperationData, OperationDataType
|
||||
from .node_handler import ModuleHandler
|
||||
from .node_handler import MetaInfoModuleHandler, ModuleHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
|
||||
|
||||
|
@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
|
|||
@operator_registry.register(torch.nn.AvgPool1d)
|
||||
@operator_registry.register(torch.nn.AvgPool2d)
|
||||
@operator_registry.register(torch.nn.AvgPool3d)
|
||||
class NormPoolingHandler(ModuleHandler):
|
||||
class NormPoolingHandler(MetaInfoModuleHandler):
|
||||
"""
|
||||
A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue