mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
5.3 KiB
147 lines
5.3 KiB
import time |
|
|
|
import pytest |
|
import torch |
|
from torch.utils._pytree import tree_map |
|
|
|
import colossalai |
|
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer |
|
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize |
|
from colossalai.auto_parallel.offload.solver import NOT_NVML |
|
from colossalai.fx.profiler import parameter_size |
|
from colossalai.nn.optimizer import HybridAdam |
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
|
from colossalai.utils import get_current_device |
|
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper |
|
from tests.test_auto_parallel.test_offload.model_utils import * |
|
from tests.test_tensor.common_utils import set_seed |
|
|
|
|
|
def _report_perf(header: str, time_list: list, param_size: float, num_best: int = 5) -> None:
    """Print the timing and CUDA peak-memory summary of one benchmark run.

    Args:
        header: First line to print, identifying the benchmarked setup.
        time_list: Per-step forward+backward durations in seconds.
        param_size: Parameter-size figure echoed in the summary line.
        num_best: Number of fastest steps averaged into ``exec_time``.
    """
    # Only the fastest steps are averaged, which leaves the slowest
    # iterations (e.g. the first, un-warmed ones) out of the figure.
    fastest = sorted(time_list)[:num_best]
    exec_time = sum(fastest) / max(len(fastest), 1)
    # Peak values since the last torch.cuda.reset_peak_memory_stats(), in MiB.
    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
    print(header)
    print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
          f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
    print(time_list)


@parameterize('model_name', ['gpt2_'])
@parameterize('memory_budget', [5000])
@parameterize('solver_name', ['asyn'])
def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
    """Benchmark forward+backward of a Gemini-wrapped model against an offload-optimized one.

    Two copies of the same model are built: one is passed through
    ``memory_optimize`` and driven by ``AMPOptimizer``, the other is wrapped by
    ``zero_model_wrapper`` / ``zero_optim_wrapper``. Each runs 10 training
    steps; the forward+backward time of every step and the CUDA peak memory of
    the whole run are printed. Nothing is asserted or returned.

    Args:
        model_name: Key looked up in ``non_distributed_component_funcs``.
        memory_budget: Budget handed to ``memory_optimize`` after being
            multiplied by 1024 * 1024 (presumably MB -> bytes; confirm against
            ``memory_optimize``).
        solver_name: Solver name forwarded to ``memory_optimize``.
    """
    num_steps = 10

    def wrap_fn(x):
        # Cast floating-point tensors to half precision; leave everything else as is.
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype=torch.half)
        return x

    # build model
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, data_gen = get_components_func()
    label = torch.randint(low=0, high=128, size=(64, 8), device=get_current_device())
    criterion = LMLoss()

    set_seed(42)
    # perf_counter() is monotonic and high-resolution; time.time() follows the
    # wall clock, which can be adjusted while an interval is being measured.
    start_time = time.perf_counter()
    model = model_builder()
    model.train()
    # NOTE(review): the extra "/ 2" presumably converts an fp32 size into the
    # fp16 size actually used for training - confirm.
    param_size = parameter_size(model) / 1024**2 / 2
    init_time = time.perf_counter() - start_time
    print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")

    # The solver receives CPU sample inputs already cast to half precision.
    data_args = data_gen(device="cpu")
    data_args = tree_map(wrap_fn, data_args)
    start_time = time.perf_counter()
    model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
    solver_time = time.perf_counter() - start_time
    print(f"solver_time={solver_time:.3f} s")

    hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
    optim = AMPOptimizer(hybrid_optimizer, model)

    # Second copy of the model, built under ColoInitContext on the CPU.
    with ColoInitContext(device=torch.device('cpu')):
        gemini_model = model_builder()
    gemini_model.train()

    hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_config = dict(strict_ddp_mode=False,
                         device=torch.device('cpu'),
                         placement_policy='cpu',
                         pin_memory=True,
                         hidden_dim=8192,
                         search_range_mb=128)
    gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
    optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
    gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)

    # Start the peak-memory measurement from a clean slate.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    # test gemini
    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    for _ in range(num_steps):
        gemini_optim.zero_grad()
        # Synchronize on both sides so only this step's forward+backward is timed.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        gemini_out = gemini_model(**data_args)
        gemini_loss = criterion(gemini_out, label)
        gemini_optim.backward(gemini_loss)
        torch.cuda.synchronize()
        time_list.append(time.perf_counter() - start_time)
        # The optimizer step is deliberately outside the timed region.
        gemini_optim.step()

    torch.cuda.synchronize()

    _report_perf(f'gemini | model_name: {model_name}', time_list, param_size)

    # Drop every reference to the Gemini run before measuring the second one.
    # `hybrid_optimizer` still points at the Gemini HybridAdam, so it has to go
    # as well or that optimizer stays alive during the next measurement.
    del data_args
    del gemini_model
    del gemini_optim
    del gemini_out
    del gemini_loss
    del hybrid_optimizer

    # test asyn offload
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    time_list = []
    set_seed(42)
    data_args = data_gen(device="cuda")
    data_args = tree_map(wrap_fn, data_args)
    for _ in range(num_steps):
        optim.zero_grad()
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        loss = criterion(model(**data_args), label)
        optim.backward(loss)
        torch.cuda.synchronize()
        time_list.append(time.perf_counter() - start_time)
        optim.step()

    torch.cuda.synchronize()

    _report_perf(f'solver_name: {solver_name} | model_name: {model_name}', time_list, param_size)
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process entry point: initialise ColossalAI, then run the benchmark.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` used for the launch.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    # Called without arguments: the @parameterize decorators on exam_fwd_bwd
    # are expected to supply model_name / memory_budget / solver_name.
    exam_fwd_bwd()
|
|
|
|
|
@pytest.mark.skip(reason="this test failed")
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@rerun_if_address_is_in_use()
def test_perf():
    """Spawn a single worker process that runs the offload performance benchmark.

    The test is currently skipped unconditionally, and additionally skipped
    whenever ``NOT_NVML`` is truthy.
    """
    nprocs = 1
    spawn(run_dist, nprocs)
|
|
|
|
|
# Allow running the benchmark directly as a script, without going through pytest.
if __name__ == '__main__':
    test_perf()
|
|
|