diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
new file mode 100644
index 000000000..f9725043e
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -0,0 +1,255 @@
+from typing import Dict, List, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.fx import GraphModule
+from torch.fx.graph import Graph
+
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.device.profile_alpha_beta import profile_alpha_beta
+from colossalai.fx.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+class ModuleWrapper(nn.Module):
+    '''
+    This class is used to wrap the original module and pass the sharding_spec_dict, origin_spec_dict and
+    comm_actions_dict into its forward function.
+    '''
+
+    def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+                 origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
+        '''
+        Args:
+            module: the original module
+            sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required by its user nodes.
+            origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor.
+            comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor.
+        '''
+        super(ModuleWrapper, self).__init__()
+        self.module = module
+        self.sharding_spec_dict = sharding_spec_dict
+        self.origin_spec_dict = origin_spec_dict
+        self.comm_actions_dict = comm_actions_dict
+
+    def forward(self, *args, **kwargs):
+        return self.module(*args,
+                           sharding_spec_convert_dict=self.sharding_spec_dict,
+                           origin_node_sharding_spec_dict=self.origin_spec_dict,
+                           comm_actions_dict=self.comm_actions_dict,
+                           **kwargs)
+
+
+def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
+    '''
+    This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func.
+    '''
+    # TODO: implement this function
+    pass
+
+
+def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
+    '''
+    This method is used to search the best logical mesh shape for the given world size
+    based on the alpha_beta_dict.
+
+    For example:
+        if the world_size is 8, the possible logical mesh shapes are (1, 8), (2, 4), (4, 2) and (8, 1).
+    '''
+    # TODO: implement this function
+    return (world_size, 1)
+
+
+def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
+    '''
+    This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
+    from the alpha_beta_dict. These two values will be used to estimate the communication cost.
+    '''
+    # TODO: implement this function
+    pass
+
+
+def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
+    '''
+    This method is used to build the strategies_constructor for the given graph.
+    After this method, each node in the graph will have a strategies_vector which
+    is constructed by the related node handler.
+    '''
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    return strategies_constructor
+
+
+def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+    '''
+    This method is used to solve the best sharding solution for the given graph.
+    The solution is a list of integers, where each integer represents the index of the best strategy for the corresponding node.
+    '''
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    cost_graph = CostGraph(strategy_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
+    ret = solver.call_solver_serialized_args()
+    solution = list(ret[0])
+
+    return solution
+
+
+def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
+                               strategies_constructor: StrategiesConstructor):
+    '''
+    This method is used to transform the original graph into the sharded graph.
+    The model parameters will be sharded according to the solution and the grad hooks
+    will be added to the sharded graph using the runtime_preparation_pass.
+    The communication nodes will be added to the graph using the runtime_apply_pass.
+    '''
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+        gm, solution, device_mesh, strategies_constructor)
+    gm = runtime_apply_pass(gm)
+    gm.recompile()
+    sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+
+    return gm, sharding_spec_dicts
+
+
+def initialize_device_mesh(world_size: int = -1,
+                           alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+                           logical_mesh_shape: Tuple[int] = None):
+    '''
+    This method is used to initialize the device mesh.
+
+    Args:
+        world_size(optional): the size of the device mesh. If the world_size is -1,
+            the world size of the current distributed process group will be used.
+        alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
+            for each group of devices. If the alpha_beta_dict is None, it will be
+            generated by the profile_alpha_beta function.
+        logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
+            mesh shape. If the logical_mesh_shape is None, it will be
+            generated by the search_best_logical_mesh_shape function.
+    '''
+    # if world_size is not set, use the world size from torch.distributed
+    if world_size == -1:
+        world_size = dist.get_world_size()
+    device1d = [i for i in range(world_size)]
+
+    if alpha_beta_dict is None:
+        # if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
+        alpha_beta_dict = profile_alpha_beta(device1d)
+
+    if logical_mesh_shape is None:
+        # search for the best logical mesh shape
+        logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
+
+    # extract alpha and beta values for the chosen logical mesh shape
+    mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
+    physical_mesh = torch.tensor(device1d)
+    device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
+                             mesh_shape=logical_mesh_shape,
+                             mesh_alpha=mesh_alpha,
+                             mesh_beta=mesh_beta,
+                             init_process_group=True)
+    return device_mesh
+
+
+def initialize_model(model: nn.Module,
+                     meta_args: Dict[str, torch.Tensor],
+                     device_mesh: DeviceMesh,
+                     memory_budget: float = -1.0,
+                     save_solver_solution: bool = False,
+                     load_solver_solution: bool = False,
+                     solution_path: str = None):
+    '''
+    This method is used to initialize the sharded model, which can then be used as a normal PyTorch model.
+
+    Args:
+        model: the model to be sharded.
+        meta_args: the meta_args are used to specify the input shapes of the model.
+        device_mesh: the device mesh on which to execute the model.
+        memory_budget(optional): the maximum CUDA memory that can be used. If the memory_budget is -1.0,
+            the memory budget is treated as infinite.
+        save_solver_solution(optional): if save_solver_solution is True, the solution will be saved
+            to the solution_path.
+        load_solver_solution(optional): if load_solver_solution is True, the solution will be loaded
+            from the solution_path.
+        solution_path(optional): the path to save or load the solution.
+    '''
+    tracer = ColoTracer()
+
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    strategies_constructor = build_strategy_constructor(graph, device_mesh)
+    if load_solver_solution:
+        solution = torch.load(solution_path)
+    else:
+        solution = solve_solution(gm, strategies_constructor, memory_budget)
+        if save_solver_solution:
+            torch.save(solution, solution_path)
+
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+    model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
+
+    return model_to_return
+
+
+def autoparallelize(model: nn.Module,
+                    meta_args: Dict[str, torch.Tensor] = None,
+                    data_loader: torch.utils.data.DataLoader = None,
+                    data_process_func: callable = None,
+                    alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
+                    logical_mesh_shape: Tuple[int] = None,
+                    save_solver_solution: bool = False,
+                    load_solver_solution: bool = False,
+                    solver_solution_path: str = None,
+                    memory_budget: float = -1.0):
+    '''
+    This method is used to initialize the device mesh, extract the meta_args, and
+    use them to create a sharded model.
+
+    Args:
+        model: the model to be sharded.
+        meta_args(optional): the meta_args are used to specify the input shapes of the model.
+            If the meta_args is None, the meta_args will be extracted from the data_loader.
+        data_loader(optional): the data_loader to be used in the normal training loop.
+        data_process_func(optional): the data_process_func is used to process the data from the data_loader.
+        alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
+            for each group of devices. If the alpha_beta_dict is None, it will be
+            generated by the profile_alpha_beta function.
+        logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
+            mesh shape. If the logical_mesh_shape is None, it will be
+            generated by the search_best_logical_mesh_shape function.
+        save_solver_solution(optional): if save_solver_solution is True, the solution will be saved
+            to the solver_solution_path.
+        load_solver_solution(optional): if load_solver_solution is True, the solution will be loaded
+            from the solver_solution_path.
+        solver_solution_path(optional): the path to save or load the solution.
+        memory_budget(optional): the maximum CUDA memory that can be used. If the memory_budget is -1.0,
+            the memory budget is treated as infinite.
+    '''
+    device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
+    if meta_args is None:
+        meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
+    model = initialize_model(model,
+                             meta_args,
+                             device_mesh,
+                             save_solver_solution=save_solver_solution,
+                             load_solver_solution=load_solver_solution,
+                             solution_path=solver_solution_path,
+                             memory_budget=memory_budget)
+
+    return model
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
index ac5b1d983..0979d8353 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
@@ -17,6 +17,7 @@ from torch.profiler import ProfilerActivity, profile, record_function, schedule,
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
@@ -80,12 +81,9 @@ def main():
     model = GPT2LMHeadModel(config=config).to('cuda')
     global_numel = sum([p.numel() for p in model.parameters()])
 
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-
     meta_input_sample = {
-        'input_ids': input_ids.to('meta'),
-        'attention_mask': attention_mask.to('meta'),
+        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
+        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
     }
 
     physical_mesh_id = torch.arange(0, 4)
@@ -93,39 +91,8 @@ def main():
     # [[0, 1]
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    shape_consistency_manager = ShapeConsistencyManager()
-
-    tracer = ColoTracer()
-
-    graph = tracer.trace(root=model, meta_args=meta_input_sample)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-
-
cost_graph = CostGraph(strategies_constructor.leaf_strategies) - cost_graph.simplify_graph() - solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1) - ret = solver.call_solver_serialized_args() - - solution = list(ret[0]) - # solution = [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 2, 13, 8, 9, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 0, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 12, 8, 8, 8, 0, 0, 20, 12, 12, 12, 6, 6, 6, 6, 2, 6, 0, 0, 4, 0, 0, 0, 4, 0, 4, 3, 3, 12, 3, 3, 8, 8, 8, 8, 8, 8, 8, 8, 3, 8, 2, 2, 11, 4, 4, 9, 0, 0, 8, 0] - print(solution) - gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) - gm = runtime_apply_pass(gm) - gm.recompile() - # *******************strategy selected******************* - print("*******************strategy selected*******************") - strategies_list = solution - - nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] - for index, node in enumerate(nodes): - print(node.name, node.strategies_vector[strategies_list[index]].name) + + gm = initialize_model(model, meta_input_sample, device_mesh) # build criterion criterion = GPTLMLoss() @@ -146,7 +113,7 @@ def main(): input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE) optimizer.zero_grad() start = time() - outputs = gm(input_ids, attn_mask, sharding_spec_dict, origin_spec_dict, comm_actions_dict) + outputs = gm(input_ids, attn_mask) loss = criterion(outputs, input_ids) loss.backward() optimizer.step()
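
Note on usage: the updated test collapses the previous manual pipeline (trace, build strategies, solve, apply runtime passes) into a single initialize_model call, and the ModuleWrapper forwards the sharding-spec dicts internally so the call site no longer passes them. The fragment below is a minimal sketch of that pattern, assuming the BATCH_SIZE, SEQ_LENGTH and VOCAB_SIZE constants, the GPT2LMHeadModel instance and the get_data helper from the test script, and a (2, 2) mesh_shape as implied by the device mesh comment.

    meta_input_sample = {
        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
    }
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)
    gm = initialize_model(model, meta_input_sample, device_mesh)

    input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
    # the sharding_spec_dict / origin_spec_dict / comm_actions_dict arguments are
    # no longer passed here; ModuleWrapper injects them into the forward call
    outputs = gm(input_ids, attn_mask)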
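extract_meta_args_from_dataloader is left as a TODO in this patch. Below is a minimal sketch of one possible implementation, assuming data_process_func maps a raw batch to the dict of keyword arguments expected by the model's forward; that contract is an assumption, not something the patch specifies.

    import torch
    from torch.utils.data import DataLoader


    def extract_meta_args_from_dataloader(data_loader: DataLoader, data_process_func: callable):
        # pull a single batch and let the user-provided function turn it into
        # the keyword arguments expected by the model's forward
        batch = next(iter(data_loader))
        kwargs = data_process_func(batch)    # assumed to return a Dict[str, torch.Tensor]
        # keep only shape/dtype information by moving every tensor to the meta device
        return {name: tensor.to('meta') for name, tensor in kwargs.items()}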
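search_best_logical_mesh_shape is likewise stubbed out and currently returns (world_size, 1). The helper below is only a sketch of the candidate enumeration its docstring describes (the function name is hypothetical); how the profiled alpha/beta values should rank the candidates is not defined in this patch, so the ranking step is deliberately omitted.

    def enumerate_logical_mesh_shapes(world_size: int):
        # for world_size = 8 this yields (1, 8), (2, 4), (4, 2) and (8, 1); the
        # best candidate would then be picked with a communication-cost estimate
        # built from the alpha_beta_dict
        shapes = []
        for dim0 in range(1, world_size + 1):
            if world_size % dim0 == 0:
                shapes.append((dim0, world_size // dim0))
        return shapes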