[autoparallel] added new strategy constructor template (#1661)

* [autoparallel] added new strategy constructor template

* polish code
pull/1652/head^2
Frank Lee 2022-09-28 14:01:36 +08:00 committed by GitHub
parent 3a4d6f63a8
commit 9ec401a722
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 5 deletions

View File

@ -6,8 +6,9 @@ from .reshape_handler import ReshapeHandler
from .bcast_op_handler import BcastOpHandler
from .embedding_handler import EmbeddingHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
__all__ = [
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
'UnaryElementwiseHandler', 'EmbeddingHandler'
'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
]

View File

@ -14,7 +14,7 @@ class Registry:
return wrapper
def get(self, source):
assert source in self.store
assert source in self.store, f'{source} not found in the {self.name} registry'
target = self.store[source]
return target

View File

@ -49,9 +49,10 @@ class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
ARG = 0
PARAM = 1
OUTPUT = 2
INPUT = 0
ARG = 1
PARAM = 2
OUTPUT = 3
@dataclass

View File

@ -4,6 +4,7 @@ from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerN
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
from .options import SolverOptions
from . import ShardingStrategy, StrategiesVector
from .op_handler import *
@ -16,6 +17,8 @@ from typing import Dict, List
from ._utils import generate_sharding_spec, generate_resharding_costs
import builtins
__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2']
class StrategiesConstructor:
"""
@ -49,6 +52,7 @@ class StrategiesConstructor:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
@ -394,3 +398,87 @@ class StrategiesConstructor:
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
class StrategiesConstructor_V2:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
Note that this operation is in-place.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def build_strategies_and_cost(self):
"""
This method is to build the strategy vector for each node in the computation graph.
"""
for node in self.nodes:
strategies_vector = StrategiesVector(node)
# placeholder node
if node.op == 'placeholder':
# TODO: implement placeholder node handler
pass
# get_attr node
elif node.op == 'get_attr':
# TODO: implement getattr node handler
pass
# call_module node
elif node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# call_function node
elif node.op == 'call_function':
target = node.target
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# output node
elif node.op == 'output':
# TODO: implement output node handler
pass
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector