mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added new strategy constructor template (#1661)
* [autoparallel] added new strategy constructor template * polish codepull/1652/head^2
parent
3a4d6f63a8
commit
9ec401a722
|
@ -6,8 +6,9 @@ from .reshape_handler import ReshapeHandler
|
|||
from .bcast_op_handler import BcastOpHandler
|
||||
from .embedding_handler import EmbeddingHandler
|
||||
from .unary_elementwise_handler import UnaryElementwiseHandler
|
||||
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
|
||||
|
||||
__all__ = [
|
||||
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
||||
'UnaryElementwiseHandler', 'EmbeddingHandler'
|
||||
'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
|
||||
]
|
||||
|
|
|
@ -14,7 +14,7 @@ class Registry:
|
|||
return wrapper
|
||||
|
||||
def get(self, source):
|
||||
assert source in self.store
|
||||
assert source in self.store, f'{source} not found in the {self.name} registry'
|
||||
target = self.store[source]
|
||||
return target
|
||||
|
||||
|
|
|
@ -49,9 +49,10 @@ class OperationDataType(Enum):
|
|||
"""
|
||||
An operation can come from the argument list of an operator or the parameter list of a module.
|
||||
"""
|
||||
ARG = 0
|
||||
PARAM = 1
|
||||
OUTPUT = 2
|
||||
INPUT = 0
|
||||
ARG = 1
|
||||
PARAM = 2
|
||||
OUTPUT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -4,6 +4,7 @@ from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerN
|
|||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
|
||||
from .options import SolverOptions
|
||||
from . import ShardingStrategy, StrategiesVector
|
||||
from .op_handler import *
|
||||
|
@ -16,6 +17,8 @@ from typing import Dict, List
|
|||
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||
import builtins
|
||||
|
||||
__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2']
|
||||
|
||||
|
||||
class StrategiesConstructor:
|
||||
"""
|
||||
|
@ -49,6 +52,7 @@ class StrategiesConstructor:
|
|||
name_checklist.append(strategy.name)
|
||||
else:
|
||||
remove_list.append(strategy)
|
||||
|
||||
for strategy in remove_list:
|
||||
strategies_vector.remove(strategy)
|
||||
|
||||
|
@ -394,3 +398,87 @@ class StrategiesConstructor:
|
|||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
||||
|
||||
|
||||
class StrategiesConstructor_V2:
|
||||
"""
|
||||
StrategiesConstructor is used to construct the parallelization plan for the model execution.
|
||||
|
||||
Args:
|
||||
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
||||
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
|
||||
self.graph = graph
|
||||
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||
self.root_module = self.graph.owning_module
|
||||
self.nodes = list(graph.nodes)
|
||||
self.device_mesh = device_mesh
|
||||
self.leaf_strategies = []
|
||||
self.strategy_map = {}
|
||||
self.solver_options = solver_options
|
||||
|
||||
def remove_duplicated_strategy(self, strategies_vector):
|
||||
'''
|
||||
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||
In this method, we will remove the duplicated strategies depending on the strategies name.
|
||||
Note that this operation is in-place.
|
||||
'''
|
||||
name_checklist = []
|
||||
remove_list = []
|
||||
for strategy in strategies_vector:
|
||||
if strategy.name not in name_checklist:
|
||||
name_checklist.append(strategy.name)
|
||||
else:
|
||||
remove_list.append(strategy)
|
||||
for strategy in remove_list:
|
||||
strategies_vector.remove(strategy)
|
||||
|
||||
def build_strategies_and_cost(self):
|
||||
"""
|
||||
This method is to build the strategy vector for each node in the computation graph.
|
||||
"""
|
||||
for node in self.nodes:
|
||||
strategies_vector = StrategiesVector(node)
|
||||
|
||||
# placeholder node
|
||||
if node.op == 'placeholder':
|
||||
# TODO: implement placeholder node handler
|
||||
pass
|
||||
|
||||
# get_attr node
|
||||
elif node.op == 'get_attr':
|
||||
# TODO: implement getattr node handler
|
||||
pass
|
||||
|
||||
# call_module node
|
||||
elif node.op == 'call_module':
|
||||
target = node.target
|
||||
submod = self.root_module.get_submodule(target)
|
||||
submod_type = type(submod)
|
||||
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
|
||||
handler.register_strategy()
|
||||
|
||||
# call_function node
|
||||
elif node.op == 'call_function':
|
||||
target = node.target
|
||||
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
|
||||
handler.register_strategy()
|
||||
|
||||
# call_method node
|
||||
elif node.op == 'call_method':
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
|
||||
handler.register_strategy()
|
||||
|
||||
# output node
|
||||
elif node.op == 'output':
|
||||
# TODO: implement output node handler
|
||||
pass
|
||||
|
||||
self.remove_duplicated_strategy(strategies_vector)
|
||||
setattr(node, 'strategies_vector', strategies_vector)
|
||||
self.leaf_strategies.append(strategies_vector)
|
||||
self.strategy_map[node] = strategies_vector
|
||||
|
|
Loading…
Reference in New Issue