mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] integrate device mesh initialization into autoparallelize (#2393)
* [autoparallel] integrate device mesh initialization into autoparallelize * add megatron solution * update gpt autoparallel examples with latest api * adapt beta value to fit the current computation costpull/2415/head^2
parent
c72c827e95
commit
2731531bc2
|
@ -59,18 +59,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
|
|||
pass
|
||||
|
||||
|
||||
def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
|
||||
'''
|
||||
This method is used to search the best logical mesh shape for the given world size
|
||||
based on the alpha_beta_dict.
|
||||
|
||||
For example:
|
||||
if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
|
||||
'''
|
||||
# TODO: implement this function
|
||||
return (world_size, 1)
|
||||
|
||||
|
||||
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
|
||||
'''
|
||||
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
|
||||
|
@ -127,39 +115,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
|
|||
|
||||
|
||||
def initialize_device_mesh(world_size: int = -1,
|
||||
physical_devices: List[int] = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None):
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
logical_mesh_id: torch.Tensor = None):
|
||||
'''
|
||||
This method is used to initialize the device mesh.
|
||||
|
||||
Args:
|
||||
world_size(optional): the size of device mesh. If the world_size is -1,
|
||||
world_size: the size of device mesh. If the world_size is -1,
|
||||
the world size will be set to the number of GPUs in the current machine.
|
||||
physical_devices: the physical devices used to initialize the device mesh.
|
||||
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
|
||||
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
|
||||
generated by profile_alpha_beta function.
|
||||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
mesh shape.
|
||||
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
|
||||
'''
|
||||
# if world_size is not set, use the world size from torch.distributed
|
||||
if world_size == -1:
|
||||
world_size = dist.get_world_size()
|
||||
device1d = [i for i in range(world_size)]
|
||||
|
||||
if physical_devices is None:
|
||||
physical_devices = [i for i in range(world_size)]
|
||||
physical_mesh = torch.tensor(physical_devices)
|
||||
|
||||
if alpha_beta_dict is None:
|
||||
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
|
||||
alpha_beta_dict = profile_alpha_beta(device1d)
|
||||
ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||
alpha_beta_dict = ab_profiler.alpha_beta_dict
|
||||
else:
|
||||
ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
|
||||
|
||||
if logical_mesh_shape is None:
|
||||
if logical_mesh_shape is None and logical_mesh_id is None:
|
||||
# search for the best logical mesh shape
|
||||
logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
|
||||
logical_mesh_id = ab_profiler.search_best_logical_mesh()
|
||||
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
|
||||
logical_mesh_shape = logical_mesh_id.shape
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
|
||||
|
||||
elif logical_mesh_shape is not None and logical_mesh_id is None:
|
||||
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
|
||||
|
||||
# extract alpha and beta values for the chosen logical mesh shape
|
||||
mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
|
||||
physical_mesh = torch.tensor(device1d)
|
||||
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
|
||||
mesh_shape=logical_mesh_shape,
|
||||
logical_mesh_id=logical_mesh_id,
|
||||
mesh_alpha=mesh_alpha,
|
||||
mesh_beta=mesh_beta,
|
||||
init_process_group=True)
|
||||
|
@ -224,6 +229,7 @@ def autoparallelize(model: nn.Module,
|
|||
data_process_func: callable = None,
|
||||
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
|
||||
logical_mesh_shape: Tuple[int] = None,
|
||||
logical_mesh_id: torch.Tensor = None,
|
||||
save_solver_solution: bool = False,
|
||||
load_solver_solution: bool = False,
|
||||
solver_solution_path: str = None,
|
||||
|
@ -245,6 +251,7 @@ def autoparallelize(model: nn.Module,
|
|||
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
|
||||
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
|
||||
generated by search_best_logical_mesh_shape function.
|
||||
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
|
||||
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
|
||||
to the solution_path.
|
||||
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
|
||||
|
@ -254,7 +261,9 @@ def autoparallelize(model: nn.Module,
|
|||
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
|
||||
the memory budget will be infinity.
|
||||
'''
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
|
||||
device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
|
||||
logical_mesh_shape=logical_mesh_shape,
|
||||
logical_mesh_id=logical_mesh_id)
|
||||
if meta_args is None:
|
||||
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
|
||||
|
||||
|
@ -263,7 +272,7 @@ def autoparallelize(model: nn.Module,
|
|||
device_mesh,
|
||||
save_solver_solution=save_solver_solution,
|
||||
load_solver_solution=load_solver_solution,
|
||||
solver_solution_path=solver_solution_path,
|
||||
solution_path=solver_solution_path,
|
||||
return_solution=return_solution,
|
||||
memory_budget=memory_budget)
|
||||
|
||||
|
|
|
@ -381,6 +381,8 @@ class AlphaBetaProfiler:
|
|||
first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
|
||||
second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
|
||||
mesh_alpha = [first_latency, second_latency]
|
||||
mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]
|
||||
# The beta values have been enlarged by 1e10 times temporarilly because the computation cost
|
||||
# is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
|
||||
mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
|
||||
|
||||
return mesh_alpha, mesh_beta
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import operator
|
||||
from functools import reduce
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -15,7 +16,8 @@ class DeviceMesh:
|
|||
|
||||
Arguments:
|
||||
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
|
||||
mesh_shape (torch.Size): shape of logical view.
|
||||
logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
|
||||
mesh_shape (torch.Size, optional): shape of logical view.
|
||||
mesh_alpha (List[float], optional): coefficients used for computing
|
||||
communication cost (default: None)
|
||||
mesh_beta (List[float], optional): coefficients used for computing
|
||||
|
@ -28,15 +30,21 @@ class DeviceMesh:
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
physical_mesh_id,
|
||||
mesh_shape,
|
||||
mesh_alpha=None,
|
||||
mesh_beta=None,
|
||||
init_process_group=False,
|
||||
need_flatten=True):
|
||||
physical_mesh_id: torch.Tensor,
|
||||
mesh_shape: torch.Size = None,
|
||||
logical_mesh_id: torch.Tensor = None,
|
||||
mesh_alpha: List[float] = None,
|
||||
mesh_beta: List[float] = None,
|
||||
init_process_group: bool = False,
|
||||
need_flatten: bool = True):
|
||||
self.physical_mesh_id = physical_mesh_id
|
||||
self.mesh_shape = mesh_shape
|
||||
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
|
||||
if logical_mesh_id is None:
|
||||
self.mesh_shape = mesh_shape
|
||||
self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
|
||||
else:
|
||||
self._logical_mesh_id = logical_mesh_id
|
||||
self.mesh_shape = self._logical_mesh_id.shape
|
||||
|
||||
# map global rank into logical rank
|
||||
self.convert_map = {}
|
||||
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
|
||||
|
@ -54,8 +62,8 @@ class DeviceMesh:
|
|||
if self.need_flatten and self._logical_mesh_id.dim() > 1:
|
||||
self.flatten_device_mesh = self.flatten()
|
||||
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
|
||||
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
self.mesh_beta)
|
||||
# self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
# self.mesh_beta)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
|
|
@ -16,14 +16,14 @@ from colossalai.device.device_mesh import DeviceMesh
|
|||
from colossalai.initialize import launch_from_torch
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LENGTH = 128
|
||||
HIDDEN_DIM = 3072
|
||||
BATCH_SIZE = 16
|
||||
SEQ_LENGTH = 1024
|
||||
HIDDEN_DIM = 4096
|
||||
NUM_HEADS = 16
|
||||
NUM_LAYERS = 1
|
||||
NUM_LAYERS = 4
|
||||
VOCAB_SIZE = 50257
|
||||
NUM_STEPS = 10
|
||||
FP16 = False
|
||||
FP16 = True
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
|
@ -40,7 +40,7 @@ def get_mem_info(prefix=''):
|
|||
|
||||
def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
|
||||
|
||||
|
||||
# Randomly Generated Data
|
||||
|
@ -66,13 +66,7 @@ def main():
|
|||
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
|
||||
}
|
||||
|
||||
# Both device mesh initialization and model initialization will be integrated into autoparallelize
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# Enable auto-parallel
|
||||
gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
|
||||
gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
|
||||
|
||||
# print solution on rank 0
|
||||
if gpc.get_global_rank() == 0:
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue