mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] autoparallel initialize (#2238)
parent
85178a397a
commit
8897b8f753
@ -0,0 +1,255 @@
|
|||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from torch.fx.graph import Graph
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||||
|
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
|
||||||
|
from colossalai.auto_parallel.tensor_shard.solver import (
|
||||||
|
CostGraph,
|
||||||
|
GraphAnalyser,
|
||||||
|
Solver,
|
||||||
|
SolverOptions,
|
||||||
|
StrategiesConstructor,
|
||||||
|
)
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.device.profile_alpha_beta import profile_alpha_beta
|
||||||
|
from colossalai.fx.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleWrapper(nn.Module):
|
||||||
|
'''
|
||||||
|
This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict
|
||||||
|
into the forward function.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
|
||||||
|
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
module: the original module
|
||||||
|
sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node.
|
||||||
|
origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
|
||||||
|
comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
|
||||||
|
'''
|
||||||
|
super(ModuleWrapper, self).__init__()
|
||||||
|
self.module = module
|
||||||
|
self.sharding_spec_dict = sharding_spec_dict
|
||||||
|
self.origin_spec_dict = origin_spec_dict
|
||||||
|
self.comm_actions_dict = comm_actions_dict
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.module(*args,
|
||||||
|
sharding_spec_convert_dict=self.sharding_spec_dict,
|
||||||
|
origin_node_sharding_spec_dict=self.origin_spec_dict,
|
||||||
|
comm_actions_dict=self.comm_actions_dict,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
|
||||||
|
'''
|
||||||
|
This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func.
|
||||||
|
'''
|
||||||
|
# TODO: implement this function
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
|
||||||
|
'''
|
||||||
|
This method is used to search the best logical mesh shape for the given world size
|
||||||
|
based on the alpha_beta_dict.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
|
||||||
|
'''
|
||||||
|
# TODO: implement this function
|
||||||
|
return (world_size, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
|
||||||
|
'''
|
||||||
|
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
|
||||||
|
from the alpha_beta_dict. These two values will be used to estimate the communication cost.
|
||||||
|
'''
|
||||||
|
# TODO: implement this function
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
|
||||||
|
'''
|
||||||
|
This method is used to build the strategy_constructor for the given graph.
|
||||||
|
After this method, each node in the graph will have a strategies_vector which
|
||||||
|
is constructed by the related node handler.
|
||||||
|
'''
|
||||||
|
solver_options = SolverOptions()
|
||||||
|
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||||
|
strategies_constructor.build_strategies_and_cost()
|
||||||
|
|
||||||
|
return strategies_constructor
|
||||||
|
|
||||||
|
|
||||||
|
def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
|
||||||
|
'''
|
||||||
|
This method is used to solve the best solution for the given graph.
|
||||||
|
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
|
||||||
|
'''
|
||||||
|
graph_analyser = GraphAnalyser(gm)
|
||||||
|
liveness_list = graph_analyser.liveness_analysis()
|
||||||
|
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
|
||||||
|
cost_graph.simplify_graph()
|
||||||
|
solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
|
||||||
|
ret = solver.call_solver_serialized_args()
|
||||||
|
solution = list(ret[0])
|
||||||
|
|
||||||
|
return solution
|
||||||
|
|
||||||
|
|
||||||
|
def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
|
||||||
|
strategies_constructor: StrategiesConstructor):
|
||||||
|
'''
|
||||||
|
This method is used to transform the original graph to the sharded graph.
|
||||||
|
The model parameters will be sharded according to the solution and the grad hooks
|
||||||
|
will be added to the sharded graph using the runtime_preparation_pass.
|
||||||
|
The communication node will be added into the graph using the runtime_apply_pass.
|
||||||
|
'''
|
||||||
|
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||||
|
gm, solution, device_mesh, strategies_constructor)
|
||||||
|
gm = runtime_apply_pass(gm)
|
||||||
|
gm.recompile()
|
||||||
|
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
|
||||||
|
|
||||||
|
return gm, sharding_spec_dicts
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_device_mesh(world_size: int = -1,
|
||||||
|
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||||
|
logical_mesh_shape: Tuple[int] = None):
|
||||||
|
'''
|
||||||
|
This method is used to initialize the device mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
world_size(optional): the size of device mesh. If the world_size is -1,
|
||||||
|
the world size will be set to the number of GPUs in the current machine.
|
||||||
|
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||||
|
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||||
|
generated by profile_alpha_beta function.
|
||||||
|
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||||
|
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||||
|
generated by search_best_logical_mesh_shape function.
|
||||||
|
'''
|
||||||
|
# if world_size is not set, use the world size from torch.distributed
|
||||||
|
if world_size == -1:
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
device1d = [i for i in range(world_size)]
|
||||||
|
|
||||||
|
if alpha_beta_dict is None:
|
||||||
|
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
|
||||||
|
alpha_beta_dict = profile_alpha_beta(device1d)
|
||||||
|
|
||||||
|
if logical_mesh_shape is None:
|
||||||
|
# search for the best logical mesh shape
|
||||||
|
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
|
||||||
|
|
||||||
|
# extract alpha and beta values for the chosen logical mesh shape
|
||||||
|
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
|
||||||
|
physical_mesh = torch.tensor(device1d)
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
|
||||||
|
mesh_shape=logical_mesh_shape,
|
||||||
|
mesh_alpha=mesh_alpha,
|
||||||
|
mesh_beta=mesh_beta,
|
||||||
|
init_process_group=True)
|
||||||
|
return device_mesh
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_model(model: nn.Module,
|
||||||
|
meta_args: Dict[str, torch.Tensor],
|
||||||
|
device_mesh: DeviceMesh,
|
||||||
|
memory_budget: float = -1.0,
|
||||||
|
save_solver_solution: bool = False,
|
||||||
|
load_solver_solution: bool = False,
|
||||||
|
solution_path: str = None):
|
||||||
|
'''
|
||||||
|
This method is used to initialize the sharded model which could be used as normal pytorch model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: the model to be sharded.
|
||||||
|
meta_args: the meta_args is used to specify the input shapes of the model.
|
||||||
|
device_mesh: the device mesh to execute the model.
|
||||||
|
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
|
||||||
|
the memory budget will be infinity.
|
||||||
|
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||||
|
to the solution_path.
|
||||||
|
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||||
|
from the solution_path.
|
||||||
|
solution_path(optional): the path to save or load the solution.
|
||||||
|
'''
|
||||||
|
tracer = ColoTracer()
|
||||||
|
|
||||||
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
strategies_constructor = build_strategy_constructor(graph, device_mesh)
|
||||||
|
if load_solver_solution:
|
||||||
|
solution = torch.load(solution_path)
|
||||||
|
else:
|
||||||
|
solution = solve_solution(gm, strategies_constructor, memory_budget)
|
||||||
|
if save_solver_solution:
|
||||||
|
torch.save(solution, solution_path)
|
||||||
|
|
||||||
|
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
|
||||||
|
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
|
||||||
|
|
||||||
|
return model_to_return
|
||||||
|
|
||||||
|
|
||||||
|
def autoparallelize(model: nn.Module,
|
||||||
|
meta_args: Dict[str, torch.Tensor] = None,
|
||||||
|
data_loader: torch.utils.data.DataLoader = None,
|
||||||
|
data_process_func: callable = None,
|
||||||
|
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||||
|
logical_mesh_shape: Tuple[int] = None,
|
||||||
|
save_solver_solution: bool = False,
|
||||||
|
load_solver_solution: bool = False,
|
||||||
|
solver_solution_path: str = None,
|
||||||
|
memory_budget: float = -1.0):
|
||||||
|
'''
|
||||||
|
This method is used to initialize the device mesh, extract the meta_args, and
|
||||||
|
use them to create a sharded model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: the model to be sharded.
|
||||||
|
meta_args(optional): the meta_args is used to specify the input shapes of the model.
|
||||||
|
If the meta_args is None, the meta_args will be extracted from the data_loader.
|
||||||
|
data_loader(optional): the data_loader to be used in normal training loop.
|
||||||
|
data_process_func(optional): the data_process_func is used to process the data from the data_loader.
|
||||||
|
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||||
|
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||||
|
generated by profile_alpha_beta function.
|
||||||
|
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||||
|
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||||
|
generated by search_best_logical_mesh_shape function.
|
||||||
|
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||||
|
to the solution_path.
|
||||||
|
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||||
|
from the solution_path.
|
||||||
|
solver_solution_path(optional): the path to save or load the solution.
|
||||||
|
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
|
||||||
|
the memory budget will be infinity.
|
||||||
|
'''
|
||||||
|
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
|
||||||
|
if meta_args is None:
|
||||||
|
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
|
||||||
|
model = initialize_model(model,
|
||||||
|
meta_args,
|
||||||
|
device_mesh,
|
||||||
|
save_solver_solution=save_solver_solution,
|
||||||
|
load_solver_solution=load_solver_solution,
|
||||||
|
solver_solution_path=solver_solution_path,
|
||||||
|
memory_budget=memory_budget)
|
||||||
|
|
||||||
|
return model
|
Loading…
Reference in new issue