mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added new strategy constructor template (#1661)
* [autoparallel] added new strategy constructor template * polish codepull/1652/head^2
parent
3a4d6f63a8
commit
9ec401a722
|
@ -6,8 +6,9 @@ from .reshape_handler import ReshapeHandler
|
||||||
from .bcast_op_handler import BcastOpHandler
|
from .bcast_op_handler import BcastOpHandler
|
||||||
from .embedding_handler import EmbeddingHandler
|
from .embedding_handler import EmbeddingHandler
|
||||||
from .unary_elementwise_handler import UnaryElementwiseHandler
|
from .unary_elementwise_handler import UnaryElementwiseHandler
|
||||||
|
from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
|
||||||
'UnaryElementwiseHandler', 'EmbeddingHandler'
|
'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler'
|
||||||
]
|
]
|
||||||
|
|
|
@ -14,7 +14,7 @@ class Registry:
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def get(self, source):
|
def get(self, source):
|
||||||
assert source in self.store
|
assert source in self.store, f'{source} not found in the {self.name} registry'
|
||||||
target = self.store[source]
|
target = self.store[source]
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
|
@ -49,9 +49,10 @@ class OperationDataType(Enum):
|
||||||
"""
|
"""
|
||||||
An operation can come from the argument list of an operator or the parameter list of a module.
|
An operation can come from the argument list of an operator or the parameter list of a module.
|
||||||
"""
|
"""
|
||||||
ARG = 0
|
INPUT = 0
|
||||||
PARAM = 1
|
ARG = 1
|
||||||
OUTPUT = 2
|
PARAM = 2
|
||||||
|
OUTPUT = 3
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -4,6 +4,7 @@ from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerN
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.auto_parallel.solver.op_handler.registry import operator_registry
|
||||||
from .options import SolverOptions
|
from .options import SolverOptions
|
||||||
from . import ShardingStrategy, StrategiesVector
|
from . import ShardingStrategy, StrategiesVector
|
||||||
from .op_handler import *
|
from .op_handler import *
|
||||||
|
@ -16,6 +17,8 @@ from typing import Dict, List
|
||||||
from ._utils import generate_sharding_spec, generate_resharding_costs
|
from ._utils import generate_sharding_spec, generate_resharding_costs
|
||||||
import builtins
|
import builtins
|
||||||
|
|
||||||
|
__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2']
|
||||||
|
|
||||||
|
|
||||||
class StrategiesConstructor:
|
class StrategiesConstructor:
|
||||||
"""
|
"""
|
||||||
|
@ -49,6 +52,7 @@ class StrategiesConstructor:
|
||||||
name_checklist.append(strategy.name)
|
name_checklist.append(strategy.name)
|
||||||
else:
|
else:
|
||||||
remove_list.append(strategy)
|
remove_list.append(strategy)
|
||||||
|
|
||||||
for strategy in remove_list:
|
for strategy in remove_list:
|
||||||
strategies_vector.remove(strategy)
|
strategies_vector.remove(strategy)
|
||||||
|
|
||||||
|
@ -394,3 +398,87 @@ class StrategiesConstructor:
|
||||||
setattr(node, 'strategies_vector', strategies_vector)
|
setattr(node, 'strategies_vector', strategies_vector)
|
||||||
self.leaf_strategies.append(strategies_vector)
|
self.leaf_strategies.append(strategies_vector)
|
||||||
self.strategy_map[node] = strategies_vector
|
self.strategy_map[node] = strategies_vector
|
||||||
|
|
||||||
|
|
||||||
|
class StrategiesConstructor_V2:
|
||||||
|
"""
|
||||||
|
StrategiesConstructor is used to construct the parallelization plan for the model execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||||
|
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
||||||
|
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
|
||||||
|
self.graph = graph
|
||||||
|
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||||
|
self.root_module = self.graph.owning_module
|
||||||
|
self.nodes = list(graph.nodes)
|
||||||
|
self.device_mesh = device_mesh
|
||||||
|
self.leaf_strategies = []
|
||||||
|
self.strategy_map = {}
|
||||||
|
self.solver_options = solver_options
|
||||||
|
|
||||||
|
def remove_duplicated_strategy(self, strategies_vector):
|
||||||
|
'''
|
||||||
|
In build_strategies_and_cost method, we may produce some duplicated strategies.
|
||||||
|
In this method, we will remove the duplicated strategies depending on the strategies name.
|
||||||
|
Note that this operation is in-place.
|
||||||
|
'''
|
||||||
|
name_checklist = []
|
||||||
|
remove_list = []
|
||||||
|
for strategy in strategies_vector:
|
||||||
|
if strategy.name not in name_checklist:
|
||||||
|
name_checklist.append(strategy.name)
|
||||||
|
else:
|
||||||
|
remove_list.append(strategy)
|
||||||
|
for strategy in remove_list:
|
||||||
|
strategies_vector.remove(strategy)
|
||||||
|
|
||||||
|
def build_strategies_and_cost(self):
|
||||||
|
"""
|
||||||
|
This method is to build the strategy vector for each node in the computation graph.
|
||||||
|
"""
|
||||||
|
for node in self.nodes:
|
||||||
|
strategies_vector = StrategiesVector(node)
|
||||||
|
|
||||||
|
# placeholder node
|
||||||
|
if node.op == 'placeholder':
|
||||||
|
# TODO: implement placeholder node handler
|
||||||
|
pass
|
||||||
|
|
||||||
|
# get_attr node
|
||||||
|
elif node.op == 'get_attr':
|
||||||
|
# TODO: implement getattr node handler
|
||||||
|
pass
|
||||||
|
|
||||||
|
# call_module node
|
||||||
|
elif node.op == 'call_module':
|
||||||
|
target = node.target
|
||||||
|
submod = self.root_module.get_submodule(target)
|
||||||
|
submod_type = type(submod)
|
||||||
|
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
|
||||||
|
handler.register_strategy()
|
||||||
|
|
||||||
|
# call_function node
|
||||||
|
elif node.op == 'call_function':
|
||||||
|
target = node.target
|
||||||
|
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
|
||||||
|
handler.register_strategy()
|
||||||
|
|
||||||
|
# call_method node
|
||||||
|
elif node.op == 'call_method':
|
||||||
|
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||||
|
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
|
||||||
|
handler.register_strategy()
|
||||||
|
|
||||||
|
# output node
|
||||||
|
elif node.op == 'output':
|
||||||
|
# TODO: implement output node handler
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.remove_duplicated_strategy(strategies_vector)
|
||||||
|
setattr(node, 'strategies_vector', strategies_vector)
|
||||||
|
self.leaf_strategies.append(strategies_vector)
|
||||||
|
self.strategy_map[node] = strategies_vector
|
||||||
|
|
Loading…
Reference in New Issue