2022-04-25 03:48:07 +00:00
|
|
|
import math
|
|
|
|
import time
|
2022-04-19 04:08:28 +00:00
|
|
|
import torch
|
2022-04-25 03:48:07 +00:00
|
|
|
|
|
|
|
from colossalai.utils import MultiTimer
|
|
|
|
from colossalai.context import ParallelMode, Config
|
|
|
|
from typing import List, Dict, Tuple, Callable
|
|
|
|
|
|
|
|
|
|
|
|
def get_time_stamp() -> float:
    """
    Return the time stamp for profiling.

    When CUDA is available, the device is synchronized first so that the time
    stamp is taken only after previously launched CUDA work has completed.
    On a host without CUDA the synchronization is skipped.

    Returns:
        time_stamp (float): the time given by time.time(), in seconds
    """
    # torch.cuda.synchronize() raises on a host without CUDA, and there is
    # nothing asynchronous to wait for in that case anyway.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
|
|
|
|
|
|
|
|
|
|
|
|
def get_memory_states() -> Tuple[float, float]:
    """
    Return the peak CUDA memory statistics and reset them.

    The peak values are read first; afterwards the peak counters are reset and
    the CUDA cache is emptied, so the next call reports peaks reached since
    this call. On a host without CUDA, zeros are returned and nothing is reset.

    Returns:
        max_allocated (float): the peak allocated CUDA memory in GB
        max_cached (float): the peak cached (reserved) CUDA memory in GB
    """
    # No CUDA device means no CUDA memory was used; the reset calls below
    # would raise instead of reporting that.
    if not torch.cuda.is_available():
        return 0.0, 0.0

    bytes_per_gb = 1024**3
    max_allocated = torch.cuda.max_memory_allocated() / bytes_per_gb
    max_cached = torch.cuda.max_memory_reserved() / bytes_per_gb

    # Reset only after both peaks have been read.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    return max_allocated, max_cached
|
|
|
|
|
|
|
|
|
|
|
|
def find_all_configs(device_cnt: int) -> List[Dict]:
    """
    Find all possible configurations for tensor parallelism

    The list always starts with the non-parallel (mode=None) and 1D configs.
    A 2D config is added when device_cnt is a perfect square, a 2.5D config is
    added for every depth such that device_cnt // depth is a perfect square,
    and a 3D config is added when device_cnt is a perfect cube.

    Args:
        device_cnt (int): the number of devices

    Returns:
        config_list (List[Dict]): a list of configurations, each wrapped in Config
    """

    def _is_square(num):
        # 2D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        return math.floor(math.sqrt(num))**2 == num

    def _is_cube(num):
        # 3D parallel should be implemented with at least 2 devices.
        if num <= 1:
            return False
        # Round instead of floor: the float cube root of a perfect cube can
        # land just below the integer (64 ** (1 / 3) == 3.9999999999999996),
        # and flooring it would wrongly reject the number.
        return round(num**(1. / 3.))**3 == num

    def _tensor_config(mode, **extra):
        # Build one config dict for the given tensor parallel mode.
        return dict(parallel=dict(tensor=dict(size=device_cnt, mode=mode, **extra)))

    # non-parallel and 1D configs are always available
    config_list = [_tensor_config(None), _tensor_config('1d')]

    # add 2D config only if device_cnt is a square
    if _is_square(device_cnt):
        config_list.append(_tensor_config('2d'))

    # check for 2.5D: every depth that divides device_cnt and leaves a square
    for depth in range(1, device_cnt):
        if device_cnt % depth == 0 and _is_square(device_cnt // depth):
            config_list.append(_tensor_config('2.5d', depth=depth))

    # check for 3D if device_cnt is a cube
    if _is_cube(device_cnt):
        config_list.append(_tensor_config('3d'))

    return [Config(cfg) for cfg in config_list]
|
|
|
|
|
|
|
|
|
|
|
|
def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, data_func: Callable,
                  timer: MultiTimer) -> Tuple[float, float, float, float]:
    """
    Profile the forward and backward of a model

    Args:
        model (torch.nn.Module): a PyTorch model
        warmup_steps (int): the number of steps for warmup
        profile_steps (int): the number of steps for profiling
        data_func (Callable): a function to generate random data
        timer (colossalai.utils.Multitimer): a timer instance for time recording

    Returns:
        fwd_time (float): the average time taken by forward pass in second
        bwd_time (float): the average time taken by backward pass in second
        max_allocated (float): the maximum GPU memory allocated in GB
        max_cached (float): the maximum GPU memory cached in GB
    """

    def _run_step(data):
        # Time the forward pass, then the backward pass on the mean of the output.
        timer.start('forward')
        out = model(data)
        timer.stop('forward', keep_in_history=True)
        timer.start('backward')
        out.mean().backward()
        timer.stop('backward', keep_in_history=True)

    # Warm up one batch at a time. Building every warmup batch up front would
    # keep them all alive during profiling and inflate the peak memory reading.
    for _ in range(warmup_steps):
        _run_step(data_func())

    # Reset both timers before the measured steps.
    # NOTE(review): presumably this drops the warmup history; confirm against MultiTimer.reset.
    timer.reset('forward')
    timer.reset('backward')

    for _ in range(profile_steps):
        data = data_func()
        _run_step(data)

    max_allocated, max_cached = get_memory_states()
    fwd_time = timer.get_timer('forward').get_history_mean()
    bwd_time = timer.get_timer('backward').get_history_mean()
    return fwd_time, bwd_time, max_allocated, max_cached
|
|
|
|
|
|
|
|
|
|
|
|
def get_batch_data(dim: int, batch_size: int, seq_length: int, mode: ParallelMode) -> torch.Tensor:
    """
    Return random data for profiling, moved to the CUDA device.

    The shape is (batch_size, seq_length, dim), except that for some modes the
    batch and hidden sizes are divided first:
        '2d' / '2.5d': batch_size // 2 and dim // 2
        '3d':          batch_size // 4 and dim // 2

    Args:
        dim (int): hidden size
        batch_size (int): the number of data samples
        seq_length (int): the number of tokens
        mode (ParallelMode): the tensor parallel mode; it is compared against
            the strings '2d', '2.5d' and '3d', any other value leaves the sizes unchanged

    Returns:
        data (torch.Tensor): random data
    """
    if mode in ['2d', '2.5d']:
        local_batch, local_dim = batch_size // 2, dim // 2
    elif mode == '3d':
        local_batch, local_dim = batch_size // 4, dim // 2
    else:
        local_batch, local_dim = batch_size, dim

    sample = torch.rand(local_batch, seq_length, local_dim)
    return sample.cuda()
|