Browse Source

[autoparallel] add shard option (#2696)

* [autoparallel] add shard option

* polish
pull/2719/head
YuliangLiu0306 2 years ago committed by GitHub
parent
commit
21d6a48f4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 70
      colossalai/auto_parallel/tensor_shard/initialize.py
  2. 5
      colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
  3. 31
      colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
  4. 17
      colossalai/auto_parallel/tensor_shard/node_handler/option.py
  5. 49
      colossalai/auto_parallel/tensor_shard/options.py
  6. 3
      colossalai/auto_parallel/tensor_shard/solver/__init__.py
  7. 30
      colossalai/auto_parallel/tensor_shard/solver/options.py
  8. 2
      colossalai/auto_parallel/tensor_shard/solver/solver.py
  9. 26
      colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
  10. 9
      tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
  11. 3
      tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
  12. 14
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
  13. 3
      tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
  14. 9
      tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
  15. 9
      tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py

70
colossalai/auto_parallel/tensor_shard/initialize.py

@ -8,14 +8,9 @@ from torch.fx.graph import Graph
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
shard_option: str):
'''
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
'''
solver_options = SolverOptions()
if solver_preference == 'standard':
solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp':
solver_preference = SolverPerference.TP
elif solver_preference == 'dp':
solver_preference = SolverPerference.DP
else:
raise ValueError(f'Invalid solver_preference: {solver_preference}')
if dataloader_option == 'replicated':
dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed':
dataloader_option = DataloaderOption.DISTRIBUTED
else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
if shard_option == 'standard':
shard_option = ShardOption.STANDARD
elif shard_option == 'shard':
shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis':
shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard':
shard_option = ShardOption.FULL_SHARD
else:
raise ValueError(f'Invalid shard_option: {shard_option}')
solver_options = SolverOptions(solver_perference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
the memory budget will be infinity.
overlap(optional): the overlap is used to specify whether to overlap gradient communication and
backward computing.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
graph = tracer.trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh)
strategies_constructor = build_strategy_constructor(graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
if load_solver_solution:
solution = torch.load(solution_path)
else:
@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solution_path=solver_solution_path,

5
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler
from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
@ -31,6 +30,6 @@ __all__ = [
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption',
'TransposeHandler', 'SplitHandler'
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
'SplitHandler'
]

31
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

@ -5,7 +5,7 @@ import torch
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@ -32,19 +32,19 @@ class NodeHandler(ABC):
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
'''
def __init__(
self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
) -> None:
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
shard_option: ShardOption = ShardOption.STANDARD,
solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.shard_option = shard_option
self.solver_perference = solver_perference
def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
"""
@ -187,15 +187,24 @@ class NodeHandler(ABC):
remove_strategy_list = []
for strategy in self.strategies_vector:
shard_level = 0
shard_axis_list = []
last_axis = len(self.device_mesh.mesh_shape) - 1
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axis in sharding_spec.dim_partition_dict.items():
shard_level += len(shard_axis)
for dim, shard_axes in sharding_spec.dim_partition_dict.items():
for shard_axis in shard_axes:
if shard_axis not in shard_axis_list:
shard_axis_list.append(shard_axis)
shard_level = len(shard_axis_list)
using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
if self.shard_option == ShardOption.SHARD and shard_level == 0:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.SHARD_LAST_AXIS:
if shard_level != 1 or using_last_axis == False:
remove_strategy_list.append(strategy)
for strategy in remove_strategy_list:
self.strategies_vector.remove(strategy)

17
colossalai/auto_parallel/tensor_shard/node_handler/option.py

@ -1,17 +0,0 @@
from enum import Enum
__all__ = ['ShardOption']
class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be shard using at least one device mesh axis.
FULL_SHARD: We require the node to be shard using all device mesh axes.
"""
STANDARD = 0
SHARD = 1
FULL_SHARD = 2

49
colossalai/auto_parallel/tensor_shard/options.py

@ -0,0 +1,49 @@
from dataclasses import dataclass
from enum import Enum
__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2
class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be shard using at least one device mesh axis.
SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis.
FULL_SHARD: We require the node to be shard using all device mesh axes.
TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
"""
STANDARD = 0
SHARD = 1
SHARD_LAST_AXIS = 2
FULL_SHARD = 3
class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
shard_option: ShardOption = ShardOption.STANDARD

3
colossalai/auto_parallel/tensor_shard/solver/__init__.py

@ -1,7 +1,6 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']

30
colossalai/auto_parallel/tensor_shard/solver/options.py

@ -1,30 +0,0 @@
from dataclasses import dataclass
from enum import Enum
__all__ = ['SolverOptions']
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2
class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED

2
colossalai/auto_parallel/tensor_shard/solver/solver.py

@ -33,7 +33,7 @@ class Solver:
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=True):
verbose=False):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:

26
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py

@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVe
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.device.device_mesh import DeviceMesh
from .options import DataloaderOption, SolverOptions
from ..options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
@ -101,7 +101,11 @@ class StrategiesConstructor:
# get_attr node
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
getattr_handler = GetattrHandler(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
getattr_handler.register_strategy()
# call_module node
@ -109,7 +113,11 @@ class StrategiesConstructor:
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
handler = operator_registry.get(submod_type)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
@ -118,7 +126,11 @@ class StrategiesConstructor:
# call_function node
elif node.op == 'call_function':
target = node.target
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
handler = operator_registry.get(target)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
@ -127,7 +139,11 @@ class StrategiesConstructor:
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
handler = operator_registry.get(method)(node,
self.device_mesh,
strategies_vector,
shard_option=self.solver_options.shard_option,
solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):

9
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py

@ -4,13 +4,8 @@ import transformers
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

3
tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py

@ -7,8 +7,9 @@ from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer

14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py

@ -5,7 +5,7 @@ import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@ -49,6 +49,15 @@ def check_shard_option(shard_option):
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
if shard_option == ShardOption.SHARD_LAST_AXIS:
# RR = RS x SR
assert 'RR = RS1 x S1R' in strategy_name_list
# RS= RR x RS
assert 'RS1 = RR x RS1' in strategy_name_list
return
# SS = SR x RS
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
@ -104,7 +113,8 @@ def check_shard_option(shard_option):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_shard_option():
for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
# for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
for shard_option in [ShardOption.SHARD_LAST_AXIS]:
check_shard_option(shard_option)

3
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py

@ -6,7 +6,8 @@ from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver

9
tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py

@ -1,13 +1,8 @@
import torch
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag

9
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py

@ -3,13 +3,8 @@ from torch.fx import GraphModule
from torchvision.models import resnet50
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

Loading…
Cancel
Save