You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/cli/benchmark/utils.py

153 lines
4.5 KiB

import math
import time
from typing import List, Dict, Tuple, Callable, Optional

import torch

from colossalai.utils import MultiTimer
from colossalai.context import ParallelMode, Config
def get_time_stamp() -> float:
    """
    Return the time stamp for profiling.

    When CUDA is available the device is synchronized first, so the reading
    is taken after all previously queued CUDA work has finished. On hosts
    without CUDA there is nothing to wait for and the synchronization is
    skipped.

    Returns:
        time_stamp (float): the time given by time.time(), in seconds
    """
    # CUDA kernels run asynchronously with respect to the host; wait for them
    # before reading the wall clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def get_memory_states() -> Tuple[float, float]:
    """
    Return the peak CUDA memory statistics and reset them.

    Both values are converted from bytes to GiB (divided by 1024**3). After
    they are read, the peak counters are reset and the CUDA cache is emptied,
    so a later call reports the peaks observed since this one. When CUDA is
    not available, both values are reported as 0.0 and nothing is reset.

    Returns:
        max_allocated (float): the peak allocated CUDA memory in GiB
        max_cached (float): the peak reserved (cached) CUDA memory in GiB
    """
    if not torch.cuda.is_available():
        return 0.0, 0.0

    bytes_per_gib = 1024**3
    max_allocated = torch.cuda.max_memory_allocated() / bytes_per_gib
    max_cached = torch.cuda.max_memory_reserved() / bytes_per_gib

    # start the next measurement window from a clean state
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    return max_allocated, max_cached
def find_all_configs(device_cnt: int) -> List[Dict]:
    """
    Find all possible configurations for tensor parallelism

    The result always contains the non-parallel (mode=None) and 1D configs.
    A 2D config is added when device_cnt is a perfect square, one 2.5D config
    is added for every depth in [1, device_cnt) that divides device_cnt and
    leaves a perfect square, and a 3D config is added when device_cnt is a
    perfect cube.

    Args:
        device_cnt (int): the number of devices

    Returns:
        config_list (List[Dict]): a list of configurations, each wrapped in Config
    """

    def _is_square(num):
        root = round(math.sqrt(num))
        return root**2 == num

    def _is_cube(num):
        # A float cube root is inexact (64 ** (1/3) evaluates to
        # 3.9999999999999996), so flooring it would reject real cubes.
        # Round to the nearest integer and verify by cubing it back.
        root = round(num**(1. / 3.))
        return root**3 == num

    def _tensor_config(mode, **extra):
        # every config shares the same nesting and tensor parallel size
        return dict(parallel=dict(tensor=dict(size=device_cnt, mode=mode, **extra)))

    # the non-parallel and 1D configs are always added
    config_list = [_tensor_config(None), _tensor_config('1d')]

    # add 2D config only if device_cnt is a square
    if _is_square(device_cnt):
        config_list.append(_tensor_config('2d'))

    # check for 2.5D: iterate over depth
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            config_list.append(_tensor_config('2.5d', depth=depth))

    # check for 3D if device_cnt is a cube
    if _is_cube(device_cnt):
        config_list.append(_tensor_config('3d'))

    return [Config(cfg) for cfg in config_list]
def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
                  timer: MultiTimer) -> Tuple[float, float, float, float]:
    """
    Profile the forward and backward of a model

    Each step runs the model on a fresh batch from data_func and
    back-propagates the mean of the output, timing the two phases under the
    timer names 'forward' and 'backward'.

    Args:
        model (torch.nn.Module): a PyTorch model
        warmup_steps (int): the number of steps for warmup
        profile_steps (int): the number of steps for profiling
        data_func (Callable): a function to generate random data
        timer (colossalai.utils.MultiTimer): a timer instance for time recording

    Returns:
        fwd_time (float): the average time taken by the forward pass in seconds
        bwd_time (float): the average time taken by the backward pass in seconds
        max_allocated (float): the maximum GPU memory allocated, as reported by get_memory_states
        max_cached (float): the maximum GPU memory cached, as reported by get_memory_states
    """

    def _run_step(data):
        timer.start('forward')
        out = model(data)
        timer.stop('forward', keep_in_history=True)
        timer.start('backward')
        out.mean().backward()
        timer.stop('backward', keep_in_history=True)

    # Generate each warmup batch on demand so that no warmup input is still
    # referenced while the profiled steps run.
    for _ in range(warmup_steps):
        _run_step(data_func())

    # reset both timers after warmup; the means below are meant to cover the
    # profiled steps only
    timer.reset('forward')
    timer.reset('backward')

    for _ in range(profile_steps):
        _run_step(data_func())

    max_allocated, max_cached = get_memory_states()
    fwd_time = timer.get_timer('forward').get_history_mean()
    bwd_time = timer.get_timer('backward').get_history_mean()
    return fwd_time, bwd_time, max_allocated, max_cached
def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: Optional[str],
                   device: str = 'cuda') -> torch.Tensor:
    """
    Return a random data tensor for profiling.

    The tensor has shape (batch_size, seq_length, dim), except that the batch
    and hidden sizes are reduced for some tensor parallel modes:

        - '2d' and '2.5d': batch_size // 2 and dim // 2
        - '3d': batch_size // 4 and dim // 2

    Any other mode (including None and '1d') uses the sizes as given.

    Args:
        dim (int): hidden size
        batch_size (int): the number of data samples
        seq_length (int): the number of tokens
        mode (Optional[str]): the tensor parallel mode string, e.g. None, '1d',
            '2d', '2.5d' or '3d'
        device (str): the device the data is moved to, defaults to 'cuda'

    Returns:
        data (torch.Tensor): random data on the requested device
    """
    # NOTE(review): the reduced sizes presumably match the per-device partition
    # expected by the 2D/2.5D/3D layers -- confirm against the layer
    # implementations.
    if mode in ['2d', '2.5d']:
        batch_size = batch_size // 2
        dim = dim // 2
    elif mode == '3d':
        batch_size = batch_size // 4
        dim = dim // 2

    # sample on the host first, then move, as the original `.cuda()` call did
    data = torch.rand(batch_size, seq_length, dim).to(device)
    return data