mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added solver option dataclass (#1588)
parent
82d4376c23
commit
219f66c571
|
@ -4,5 +4,9 @@ from .solver import Solver
|
||||||
from .cost_graph import CostGraph
|
from .cost_graph import CostGraph
|
||||||
from .strategies_constructor import StrategiesConstructor
|
from .strategies_constructor import StrategiesConstructor
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
from .options import SolverOptions
|
||||||
|
|
||||||
__all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
|
__all__ = [
|
||||||
|
'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph',
|
||||||
|
'SolverOptions'
|
||||||
|
]
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
__all__ = ['SolverOptions']
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SolverOptions:
|
||||||
|
"""
|
||||||
|
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
|
||||||
|
"""
|
||||||
|
fast: bool = False
|
|
@ -1,5 +1,8 @@
|
||||||
from torch.fx import Graph, Node
|
from torch.fx import Graph, Node
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from .options import SolverOptions
|
||||||
from . import ShardingStrategy, StrategiesVector
|
from . import ShardingStrategy, StrategiesVector
|
||||||
from .op_handler import *
|
from .op_handler import *
|
||||||
from .constants import *
|
from .constants import *
|
||||||
|
@ -11,9 +14,20 @@ from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
class StrategiesConstructor:
|
class StrategiesConstructor:
|
||||||
|
"""
|
||||||
|
StrategiesConstructor is used to construct the parallelization plan for the model execution.
|
||||||
|
|
||||||
def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options):
|
Args:
|
||||||
|
graph (Graph): a Graph object used for analysis and strategy generation.
|
||||||
|
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
|
||||||
|
shape_consistency_manager (ShapeConsistencyManager): a ShapeConsistencyManager object to make sure the sharding specs are consistent.
|
||||||
|
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, graph: Graph, device_mesh: DeviceMesh, shape_consistency_manager: ShapeConsistencyManager,
|
||||||
|
solver_options: SolverOptions):
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
|
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
|
||||||
self.root_module = self.graph.owning_module
|
self.root_module = self.graph.owning_module
|
||||||
self.nodes = list(graph.nodes)
|
self.nodes = list(graph.nodes)
|
||||||
self.device_mesh = device_mesh
|
self.device_mesh = device_mesh
|
||||||
|
@ -77,13 +91,13 @@ class StrategiesConstructor:
|
||||||
strategies_vector = StrategiesVector(node)
|
strategies_vector = StrategiesVector(node)
|
||||||
# placeholder node
|
# placeholder node
|
||||||
if node.op == 'placeholder':
|
if node.op == 'placeholder':
|
||||||
# For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
|
# For placeholder nodes, if solver_options.fast is True, we just let them in
|
||||||
# fully replicate status, then strategies of following node will be treated equally due
|
# fully replicate status, then strategies of following node will be treated equally due
|
||||||
# to replicate status has no resharding cost to other status. At the same time, the searching
|
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||||
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
|
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
|
||||||
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
|
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
|
||||||
|
|
||||||
if self.solver_options['fast_mode']:
|
if self.solver_options.fast:
|
||||||
# create sharding strategy for placeholder
|
# create sharding strategy for placeholder
|
||||||
name = 'Replica Placeholder'
|
name = 'Replica Placeholder'
|
||||||
dim_partition_dict = {}
|
dim_partition_dict = {}
|
||||||
|
@ -97,12 +111,12 @@ class StrategiesConstructor:
|
||||||
|
|
||||||
# get_attr node
|
# get_attr node
|
||||||
if node.op == 'get_attr':
|
if node.op == 'get_attr':
|
||||||
# Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in
|
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
|
||||||
# fully replicate status, then strategies of following node will be treated equally due
|
# fully replicate status, then strategies of following node will be treated equally due
|
||||||
# to replicate status has no resharding cost to other status. At the same time, the searching
|
# to replicate status has no resharding cost to other status. At the same time, the searching
|
||||||
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
|
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
|
||||||
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
|
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
|
||||||
if self.solver_options['fast_mode']:
|
if self.solver_options.fast:
|
||||||
# create sharding strategy for get_attr
|
# create sharding strategy for get_attr
|
||||||
name = 'Replica Attribute'
|
name = 'Replica Attribute'
|
||||||
dim_partition_dict = {}
|
dim_partition_dict = {}
|
||||||
|
@ -382,7 +396,7 @@ class StrategiesConstructor:
|
||||||
|
|
||||||
# output node
|
# output node
|
||||||
if node.op == 'output':
|
if node.op == 'output':
|
||||||
if self.solver_options['fast_mode']:
|
if self.solver_options.fast:
|
||||||
# create sharding strategy for output
|
# create sharding strategy for output
|
||||||
name = 'Replica Output'
|
name = 'Replica Output'
|
||||||
input_nodes = strategies_vector.predecessor_nodes
|
input_nodes = strategies_vector.predecessor_nodes
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from pickletools import optimize
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -10,6 +11,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||||
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,7 +54,7 @@ def test_cost_graph():
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
solver_options = {'fast_mode': True}
|
solver_options = SolverOptions(fast=True)
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from colossalai.auto_parallel.solver import Solver
|
from colossalai.auto_parallel.solver import Solver
|
||||||
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
|
|
||||||
|
|
||||||
class ConvModel(nn.Module):
|
class ConvModel(nn.Module):
|
||||||
|
@ -39,7 +40,6 @@ def test_solver():
|
||||||
# [[0, 1]
|
# [[0, 1]
|
||||||
# [2, 3]]
|
# [2, 3]]
|
||||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
entire_shape = torch.Size((4, 16, 64, 64))
|
|
||||||
shape_consistency_manager = ShapeConsistencyManager()
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
|
@ -57,9 +57,8 @@ def test_solver():
|
||||||
# return relu
|
# return relu
|
||||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
|
||||||
|
|
||||||
solver_options = {'fast_mode': True}
|
solver_options = SolverOptions(fast=True)
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||||
strategies_constructor.build_strategies_and_cost()
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||||
|
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +48,7 @@ def test_strategies_constructor():
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
|
|
||||||
solver_options = {'fast_mode': True}
|
solver_options = SolverOptions(fast=True)
|
||||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
|
||||||
|
|
||||||
assert strategies_constructor.leaf_strategies == []
|
assert strategies_constructor.leaf_strategies == []
|
||||||
|
|
Loading…
Reference in New Issue