[autoparallel] integrate device mesh initialization into autoparallelize (#2393)

* [autoparallel] integrate device mesh initialization into autoparallelize

* add megatron solution

* update gpt autoparallel examples with latest api

* adapt beta value to fit the current computation cost
pull/2415/head^2
YuliangLiu0306 2023-01-11 14:03:49 +08:00 committed by GitHub
parent c72c827e95
commit 2731531bc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 64 additions and 51 deletions

View File

@ -59,18 +59,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
pass
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
'''
This method is used to search the best logical mesh shape for the given world size
based on the alpha_beta_dict.
For example:
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
'''
# TODO: implement this function
return (world_size, 1)
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
'''
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
@ -127,39 +115,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
def initialize_device_mesh(world_size: int = -1,
physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None):
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None):
'''
This method is used to initialize the device mesh.
Args:
world_size(optional): the size of device mesh. If the world_size is -1,
world_size: the size of device mesh. If the world_size is -1,
the world size will be set to the number of GPUs in the current machine.
physical_devices: the physical devices used to initialize the device mesh.
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
generated by profile_alpha_beta function.
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
mesh shape.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
'''
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
device1d = [i for i in range(world_size)]
if physical_devices is None:
physical_devices = [i for i in range(world_size)]
physical_mesh = torch.tensor(physical_devices)
if alpha_beta_dict is None:
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
alpha_beta_dict = profile_alpha_beta(device1d)
ab_profiler = AlphaBetaProfiler(physical_devices)
alpha_beta_dict = ab_profiler.alpha_beta_dict
else:
ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
if logical_mesh_shape is None:
if logical_mesh_shape is None and logical_mesh_id is None:
# search for the best logical mesh shape
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
logical_mesh_id = ab_profiler.search_best_logical_mesh()
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
logical_mesh_shape = logical_mesh_id.shape
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
elif logical_mesh_shape is not None and logical_mesh_id is None:
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
# extract alpha and beta values for the chosen logical mesh shape
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
physical_mesh = torch.tensor(device1d)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
@ -224,6 +229,7 @@ def autoparallelize(model: nn.Module,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
@ -245,6 +251,7 @@ def autoparallelize(model: nn.Module,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@ -254,7 +261,9 @@ def autoparallelize(model: nn.Module,
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
'''
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
logical_mesh_shape=logical_mesh_shape,
logical_mesh_id=logical_mesh_id)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
@ -263,7 +272,7 @@ def autoparallelize(model: nn.Module,
device_mesh,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
solver_solution_path=solver_solution_path,
solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)

View File

@ -381,6 +381,8 @@ class AlphaBetaProfiler:
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
mesh_alpha = [first_latency, second_latency]
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
# The beta values have been enlarged by 1e10 times temporarilly because the computation cost
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
return mesh_alpha, mesh_beta

View File

@ -1,5 +1,6 @@
import operator
from functools import reduce
from typing import List, Tuple
import torch
import torch.distributed as dist
@ -15,7 +16,8 @@ class DeviceMesh:
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
mesh_shape (torch.Size): shape of logical view.
logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
mesh_shape (torch.Size, optional): shape of logical view.
mesh_alpha (List[float], optional): coefficients used for computing
communication cost (default: None)
mesh_beta (List[float], optional): coefficients used for computing
@ -28,15 +30,21 @@ class DeviceMesh:
"""
def __init__(self,
physical_mesh_id,
mesh_shape,
mesh_alpha=None,
mesh_beta=None,
init_process_group=False,
need_flatten=True):
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
logical_mesh_id: torch.Tensor = None,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
@ -54,8 +62,8 @@ class DeviceMesh:
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
self.mesh_beta)
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
# self.mesh_beta)
@property
def shape(self):

View File

@ -16,14 +16,14 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
BATCH_SIZE = 8
SEQ_LENGTH = 128
HIDDEN_DIM = 3072
BATCH_SIZE = 16
SEQ_LENGTH = 1024
HIDDEN_DIM = 4096
NUM_HEADS = 16
NUM_LAYERS = 1
NUM_LAYERS = 4
VOCAB_SIZE = 50257
NUM_STEPS = 10
FP16 = False
FP16 = True
def get_cpu_mem():
@ -40,7 +40,7 @@ def get_mem_info(prefix=''):
def get_tflops(model_numel, batch_size, seq_len, step_time):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
# Randomly Generated Data
@ -66,13 +66,7 @@ def main():
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
}
# Both device mesh initialization and model initialization will be integrated into autoparallelize
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# Enable auto-parallel
gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
# print solution on rank 0
if gpc.get_global_rank() == 0: