import math
import time
from typing import Callable, Dict, List, Tuple

import torch
from colossalai.context import Config, ParallelMode
from colossalai.utils import MultiTimer


def get_time_stamp() -> float:
    """
    Return the time stamp for profiling.

    The CUDA stream is synchronized first so that pending GPU work is
    finished before the wall-clock time is read.

    Returns:
        time_stamp (float): the time given by time.time()
    """
    torch.cuda.synchronize()
    time_stamp = time.time()
    return time_stamp


def get_memory_states() -> Tuple[float, float]:
    """
    Return the peak memory statistics and reset them.

    Besides reading the statistics, this function resets the CUDA peak
    memory counters and empties the CUDA cache, so that the next call
    reports the peak of the following measurement window only.

    Returns:
        max_allocated (float): the peak allocated CUDA memory in GiB
        max_cached (float): the peak reserved (cached) CUDA memory in GiB
    """
    max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
    max_cached = torch.cuda.max_memory_reserved() / (1024**3)
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    return max_allocated, max_cached


def find_all_configs(device_cnt: int) -> List[Dict]:
    """
    Find all possible configurations for tensor parallelism

    The non-parallel and 1D configurations are always included. The 2D
    configuration is added when ``device_cnt`` is a perfect square, one
    2.5D configuration is added for every depth that divides
    ``device_cnt`` into a perfect square, and the 3D configuration is
    added when ``device_cnt`` is a perfect cube.

    Args:
        device_cnt (int): the number of devices

    Returns:
        config_list (List[Dict]): a list of configurations, each wrapped
            in a ``Config`` object
    """

    def _is_square(num):
        # 2D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        root = math.floor(math.sqrt(num))
        return root**2 == num

    def _is_cube(num):
        # 3D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        # The floating-point cube root of a perfect cube may land just
        # below the exact integer (e.g. 64 ** (1 / 3) is
        # 3.9999999999999996), so round to the nearest integer instead of
        # flooring, then verify exactly with integer arithmetic.
        root = round(num**(1. / 3.))
        return root**3 == num

    config_list = []

    # add non-parallel config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode=None)))
    config_list.append(config)

    # add 1D config
    config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='1d')))
    config_list.append(config)

    # add 2D config only if device_cnt is a square
    if _is_square(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2d')))
        config_list.append(config)

    # check for 2.5D
    # iterate over depth
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='2.5d', depth=depth)))
            config_list.append(config)

    # check for 3D if device_cnt is a cube
    if _is_cube(device_cnt):
        config = dict(parallel=dict(tensor=dict(size=device_cnt, mode='3d')))
        config_list.append(config)

    config_list = [Config(cfg) for cfg in config_list]
    return config_list


def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
                  timer: MultiTimer) -> Tuple[float, float, float, float]:
    """
    Profile the forward and backward of a model

    The model is first run for ``warmup_steps`` steps whose timings are
    discarded, then for ``profile_steps`` steps whose timings are kept in
    the timer history and averaged.

    Args:
        model (torch.nn.Module): a PyTorch model
        warmup_steps (int): the number of steps for warmup
        profile_steps (int): the number of steps for profiling
        data_func (Callable): a function to generate random data
        timer (colossalai.utils.MultiTimer): a timer instance for time recording

    Returns:
        fwd_time (float): the average time taken by the forward pass in second
        bwd_time (float): the average time taken by the backward pass in second
        max_allocated (float): the maximum GPU memory allocated in GB
        max_cached (float): the maximum GPU memory cached in GB
    """

    def _run_step(data):
        timer.start('forward')
        out = model(data)
        timer.stop('forward', keep_in_history=True)
        timer.start('backward')
        # reduce the output to a scalar so that backward() can be called
        out.mean().backward()
        timer.stop('backward', keep_in_history=True)

    # warmup: the timings recorded here are dropped by the resets below
    data_list = [data_func() for _ in range(warmup_steps)]
    for data in data_list:
        _run_step(data)
    timer.reset('forward')
    timer.reset('backward')

    # measured steps
    for _ in range(profile_steps):
        data = data_func()
        _run_step(data)

    max_allocated, max_cached = get_memory_states()
    fwd_time = timer.get_timer('forward').get_history_mean()
    bwd_time = timer.get_timer('backward').get_history_mean()
    return fwd_time, bwd_time, max_allocated, max_cached


def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
    """
    Return a random data of shape (batch_size, seq_length, dim) for profiling.

    For the '2d' and '2.5d' modes both ``batch_size`` and ``dim`` are
    halved; for the '3d' mode ``batch_size`` is divided by 4 and ``dim``
    by 2. Any other mode leaves the sizes untouched.

    NOTE(review): ``mode`` is compared against the strings '2d', '2.5d'
    and '3d' although it is annotated as ``ParallelMode`` - confirm what
    callers actually pass.

    Args:
        dim (int): hidden size
        batch_size (int): the number of data samples
        seq_length (int): the number of tokens
        mode (ParallelMode): the tensor parallel mode

    Returns:
        data (torch.Tensor): random data placed on the current CUDA device
    """
    if mode in ['2d', '2.5d']:
        batch_size = batch_size // 2
        dim = dim // 2
    elif mode == '3d':
        batch_size = batch_size // 4
        dim = dim // 2

    data = torch.rand(batch_size, seq_length, dim).cuda()
    return data