From 2731531bc23a93282ca5408afa3b1a329c0e331d Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 11 Jan 2023 14:03:49 +0800
Subject: [PATCH] [autoparallel] integrate device mesh initialization into
 autoparallelize (#2393)

* [autoparallel] integrate device mesh initialization into autoparallelize

* add megatron solution

* update gpt autoparallel examples with latest api

* adapt beta value to fit the current computation cost
---
 .../auto_parallel/tensor_shard/initialize.py |  61 ++++++++++--------
 colossalai/device/alpha_beta_profiler.py     |   4 +-
 colossalai/device/device_mesh.py             |  30 +++++----
 .../auto_parallel/auto_parallel_with_gpt.py  |  20 ++----
 .../saved_solution/solution_12_layers.pt     | Bin 0 -> 1903 bytes
 .../saved_solution/solution_1_layers.pt      | Bin 0 -> 559 bytes
 .../saved_solution/solution_4_layers.pt      | Bin 0 -> 943 bytes
 7 files changed, 64 insertions(+), 51 deletions(-)
 create mode 100644 examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt
 create mode 100644 examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt
 create mode 100644 examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt

diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index 0dce2564c..8c24c0d7b 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -59,18 +59,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
     pass
 
 
-def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
-    '''
-    This method is used to search the best logical mesh shape for the given world size
-    based on the alpha_beta_dict.
-
-    For example:
-    if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
-    '''
-    # TODO: implement this function
-    return (world_size, 1)
-
-
 def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]],
                                        logical_mesh_shape: Tuple[int]):
     '''
     This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
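# A minimal sketch (illustration only, not part of the patch) of the mesh-shape search that the
# removed placeholder above only hinted at: for a given world size, the candidate 2-D logical
# meshes are the factor pairs of the device count, and AlphaBetaProfiler.search_best_logical_mesh
# (used in the new code below) is meant to choose among such candidates from the profiled
# latency/bandwidth. The helper name candidate_mesh_shapes is hypothetical.
def candidate_mesh_shapes(world_size: int):
    # e.g. world_size = 8 -> [(1, 8), (2, 4), (4, 2), (8, 1)]
    return [(i, world_size // i) for i in range(1, world_size + 1) if world_size % i == 0]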
@@ -127,39 +115,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
 
 
 def initialize_device_mesh(world_size: int = -1,
+                           physical_devices: List[int] = None,
                            alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
-                           logical_mesh_shape: Tuple[int] = None):
+                           logical_mesh_shape: Tuple[int] = None,
+                           logical_mesh_id: torch.Tensor = None):
     '''
     This method is used to initialize the device mesh.
 
     Args:
-        world_size(optional): the size of device mesh. If the world_size is -1,
+        world_size: the size of the device mesh. If the world_size is -1,
             the world size will be set to the number of GPUs in the current machine.
+        physical_devices: the physical devices used to initialize the device mesh.
         alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
             for each device. If the alpha_beta_dict is None, the alpha_beta_dict will be
             generated by the profile_alpha_beta function.
         logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
-            mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
-            generated by search_best_logical_mesh_shape function.
+            mesh shape.
+        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical view of the device ranks directly.
     '''
     # if world_size is not set, use the world size from torch.distributed
     if world_size == -1:
         world_size = dist.get_world_size()
-    device1d = [i for i in range(world_size)]
+
+    if physical_devices is None:
+        physical_devices = [i for i in range(world_size)]
+    physical_mesh = torch.tensor(physical_devices)
 
     if alpha_beta_dict is None:
         # if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
-        alpha_beta_dict = profile_alpha_beta(device1d)
+        ab_profiler = AlphaBetaProfiler(physical_devices)
+        alpha_beta_dict = ab_profiler.alpha_beta_dict
+    else:
+        ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
 
-    if logical_mesh_shape is None:
+    if logical_mesh_shape is None and logical_mesh_id is None:
         # search for the best logical mesh shape
-        logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
+        logical_mesh_id = ab_profiler.search_best_logical_mesh()
+        logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
+        logical_mesh_shape = logical_mesh_id.shape
+
+        # extract alpha and beta values for the chosen logical mesh shape
+        mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
+
+    elif logical_mesh_shape is not None and logical_mesh_id is None:
+        logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
+
+        # extract alpha and beta values for the chosen logical mesh shape
+        mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
 
-    # extract alpha and beta values for the chosen logical mesh shape
-    mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
-    physical_mesh = torch.tensor(device1d)
     device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
-                             mesh_shape=logical_mesh_shape,
+                             logical_mesh_id=logical_mesh_id,
                              mesh_alpha=mesh_alpha,
                              mesh_beta=mesh_beta,
                              init_process_group=True)
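# A minimal usage sketch of the reworked initialize_device_mesh (illustration only, not part of
# the patch), assuming a 4-GPU distributed launch and that the function is importable from the
# module shown in the file header above. With neither logical_mesh_shape nor logical_mesh_id
# given, the AlphaBetaProfiler branch profiles the devices and searches for the logical mesh;
# passing logical_mesh_shape pins the topology and skips the search.
from colossalai.auto_parallel.tensor_shard.initialize import initialize_device_mesh

device_mesh_searched = initialize_device_mesh(world_size=4)
device_mesh_pinned = initialize_device_mesh(world_size=4,
                                            physical_devices=[0, 1, 2, 3],
                                            logical_mesh_shape=(2, 2))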
@@ -224,6 +229,7 @@ def autoparallelize(model: nn.Module,
                     data_process_func: callable = None,
                     alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
                     logical_mesh_shape: Tuple[int] = None,
+                    logical_mesh_id: torch.Tensor = None,
                     save_solver_solution: bool = False,
                     load_solver_solution: bool = False,
                     solver_solution_path: str = None,
@@ -245,6 +251,7 @@ def autoparallelize(model: nn.Module,
         logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
             mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
             generated by search_best_logical_mesh_shape function.
+        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical view of the device ranks directly.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -254,7 +261,9 @@ def autoparallelize(model: nn.Module,
         memory_budget(optional): the max CUDA memory that could be used. If the memory budget is -1.0,
             the memory budget will be infinity.
     '''
-    device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
+    device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
+                                         logical_mesh_shape=logical_mesh_shape,
+                                         logical_mesh_id=logical_mesh_id)
     if meta_args is None:
         meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
 
@@ -263,7 +272,7 @@ def autoparallelize(model: nn.Module,
                                    device_mesh,
                                    save_solver_solution=save_solver_solution,
                                    load_solver_solution=load_solver_solution,
-                                   solver_solution_path=solver_solution_path,
+                                   solution_path=solver_solution_path,
                                    return_solution=return_solution,
                                    memory_budget=memory_budget)
 
diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
index 9c66cb85d..af2b10928 100644
--- a/colossalai/device/alpha_beta_profiler.py
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -381,6 +381,8 @@ class AlphaBetaProfiler:
         first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
         second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
         mesh_alpha = [first_latency, second_latency]
-        mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
+        # The beta values have been enlarged by a factor of 1e10 temporarily because the computation cost
+        # is still estimated in units of TFLOPs instead of time. We will remove this factor in the future.
+        mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
 
         return mesh_alpha, mesh_beta
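# To make the 1e10 factor above concrete, a sketch of the alpha-beta model these coefficients
# feed into (illustration only; estimate_comm_cost is a hypothetical helper, not an
# AlphaBetaProfiler method): communication cost is modelled as alpha + beta * message_size, so
# scaling beta keeps the result comparable with compute costs that are still counted in TFLOPs
# rather than seconds.
def estimate_comm_cost(mesh_alpha: float, mesh_beta: float, num_bytes: float) -> float:
    # classic latency/bandwidth (alpha-beta) cost model
    return mesh_alpha + mesh_beta * num_bytes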
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 7596a100b..b5a97eded 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -1,5 +1,6 @@
 import operator
 from functools import reduce
+from typing import List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -15,7 +16,8 @@ class DeviceMesh:
 
     Arguments:
         physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
-        mesh_shape (torch.Size): shape of logical view.
+        logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
+        mesh_shape (torch.Size, optional): shape of logical view.
         mesh_alpha (List[float], optional): coefficients used for computing
            communication cost (default: None)
         mesh_beta (List[float], optional): coefficients used for computing
@@ -28,15 +30,21 @@ class DeviceMesh:
     """
 
     def __init__(self,
-                 physical_mesh_id,
-                 mesh_shape,
-                 mesh_alpha=None,
-                 mesh_beta=None,
-                 init_process_group=False,
-                 need_flatten=True):
+                 physical_mesh_id: torch.Tensor,
+                 mesh_shape: torch.Size = None,
+                 logical_mesh_id: torch.Tensor = None,
+                 mesh_alpha: List[float] = None,
+                 mesh_beta: List[float] = None,
+                 init_process_group: bool = False,
+                 need_flatten: bool = True):
         self.physical_mesh_id = physical_mesh_id
-        self.mesh_shape = mesh_shape
-        self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+        if logical_mesh_id is None:
+            self.mesh_shape = mesh_shape
+            self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+        else:
+            self._logical_mesh_id = logical_mesh_id
+            self.mesh_shape = self._logical_mesh_id.shape
+
         # map global rank into logical rank
         self.convert_map = {}
         self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
@@ -54,8 +62,8 @@ class DeviceMesh:
         if self.need_flatten and self._logical_mesh_id.dim() > 1:
             self.flatten_device_mesh = self.flatten()
             # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
-            self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
-                                                           self.mesh_beta)
+            # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+            #                                                self.mesh_beta)
 
     @property
     def shape(self):
diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
index 85c8d64d7..6ceb7fd87 100644
--- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
+++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
@@ -16,14 +16,14 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch_from_torch
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 
-BATCH_SIZE = 8
-SEQ_LENGTH = 128
-HIDDEN_DIM = 3072
+BATCH_SIZE = 16
+SEQ_LENGTH = 1024
+HIDDEN_DIM = 4096
 NUM_HEADS = 16
-NUM_LAYERS = 1
+NUM_LAYERS = 4
 VOCAB_SIZE = 50257
 NUM_STEPS = 10
-FP16 = False
+FP16 = True
 
 
 def get_cpu_mem():
@@ -40,7 +40,7 @@ def get_mem_info(prefix=''):
 
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
-    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
 
 
 # Randomly Generated Data
@@ -66,13 +66,7 @@ def main():
         'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
     }
 
-    # Both device mesh initialization and model initialization will be integrated into autoparallelize
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
-    # Enable auto-parallel
-    gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
+    gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
 
     # print solution on rank 0
     if gpc.get_global_rank() == 0:
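# A standalone sketch of the TFLOPs estimate changed above (illustration only, not part of the
# patch): the 8 * numel * tokens convention from the comment is kept, and the final divisor is
# the GPU count, which this commit raises from 4 to 8. The num_gpus parameter is an addition
# for illustration.
def get_tflops_per_gpu(model_numel, batch_size, seq_len, step_time, num_gpus=8):
    # TFLOPs per GPU = numel * batch * seq_len * 8 / 1e12 / step_time / num_gpus
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / num_gpus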
diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt
new file mode 100644
index 0000000000000000000000000000000000000000..7b8cd7edd11e6d1f605e0e9f992b6a13676ecd10
GIT binary patch (literal, 1903 bytes, omitted)

diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9b431a45baba43b9581fb5cf3d4bf39a2aaea5d6
GIT binary patch (literal, 559 bytes, omitted)

diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt
new file mode 100644
index 0000000000000000000000000000000000000000..79a448c1b06f1db8731d2d45f988ff0b57810b04
GIT binary patch (literal, 943 bytes, omitted)
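# A sketch of how the saved_solution/*.pt files added above could be consumed through the
# load_solver_solution / solver_solution_path arguments introduced in this patch (illustration
# only, not part of the patch). It assumes the model and meta_input_sample objects defined in
# the updated example script, and the relative path below is illustrative.
gm, solution = autoparallelize(model,
                               meta_input_sample,
                               load_solver_solution=True,
                               solver_solution_path='saved_solution/solution_4_layers.pt',
                               return_solution=True)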