From af185b5519f768a8eb1a55b9bb610ff940e63089 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 18 Mar 2022 16:28:16 +0800
Subject: [PATCH] [test] fixed amp convergence comparison test (#454)

---
 tests/test_amp/test_naive_fp16.py             | 46 +++++++++++--------
 .../test_shard_model_v2.py                    |  2 +-
 .../test_sharded_optim_v2.py                  | 13 ++++--
 3 files changed, 35 insertions(+), 26 deletions(-)

diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py
index c6805ad51..c3554f8ca 100644
--- a/tests/test_amp/test_naive_fp16.py
+++ b/tests/test_amp/test_naive_fp16.py
@@ -3,7 +3,7 @@ import colossalai
 import copy
 import pytest
 import torch.multiprocessing as mp
-from colossalai.amp import convert_to_naive_amp
+from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.testing import assert_close_loose
 from colossalai.utils import free_port
@@ -23,23 +23,29 @@ def run_naive_amp():
     and fp32 torch optimizer
     """
 
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
     # create layer
-    test_models = ['repeated_computed_layers', 'nested_model']
+    test_models = ['repeated_computed_layers', 'nested_model', 'resnet18']
     for test_name in test_models:
         get_component_func = non_distributed_component_funcs.get_callable(test_name)
         model_builder, train_dataloader, _, optim_class, _ = get_component_func()
 
         # create model
-        amp_model = model_builder(checkpoint=True).cuda()
-        torch_model = copy.deepcopy(amp_model)
+        naive_amp_model = model_builder(checkpoint=True).cuda()
+        apex_amp_model = copy.deepcopy(naive_amp_model)
 
         # create optimizer
-        amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
-        torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)
+        naive_amp_optimizer = optim_class(naive_amp_model.parameters(), lr=1e-3)
+        apex_amp_optimizer = optim_class(apex_amp_model.parameters(), lr=1e-3)
 
-        # inject naive amp
-        amp_config = dict(initial_scale=1)
-        amp_model, amp_optimizer = convert_to_naive_amp(amp_model, amp_optimizer, amp_config)
+        # inject naive and apex amp
+        naive_amp_config = dict(initial_scale=128)
+        naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer,
+                                                                    naive_amp_config)
+        apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False)
+        apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
 
         # create data
         data_iter = iter(train_dataloader)
@@ -47,25 +53,25 @@
         data = data.cuda()
 
         # forward pass
-        amp_output = amp_model(data)
-        torch_output = torch_model(data)
-        assert_close_loose(amp_output, torch_output)
+        naive_amp_output = naive_amp_model(data)
+        apex_amp_output = apex_amp_model(data)
+        assert_close_loose(naive_amp_output, apex_amp_output)
 
         # backward
-        amp_optimizer.backward(amp_output.mean())
-        torch_output.mean().backward()
+        naive_amp_optimizer.backward(naive_amp_output.mean())
+        apex_amp_optimizer.backward(apex_amp_output.mean())
 
         # check grad
-        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
-            assert_close_loose(amp_param.grad, torch_param.grad.half())
+        for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()):
+            assert_close_loose(naive_amp_param.grad, apex_amp_param.grad)
 
         # step
-        amp_optimizer.step()
-        torch_optimizer.step()
+        naive_amp_optimizer.step()
+        apex_amp_optimizer.step()
 
         # check updated param
-        for amp_param, torch_param in zip(amp_model.parameters(), torch_model.parameters()):
-            assert_close_loose(amp_param, torch_param.half())
+        for naive_amp_param, apex_amp_param in zip(naive_amp_model.parameters(), apex_amp_model.parameters()):
+            assert_close_loose(naive_amp_param, apex_amp_param)
 
 
 def run_dist(rank, world_size, port):
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index cab8de7d6..b22c2d86d 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -46,7 +46,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
         model = DDP(model)
 
     for i, (data, label) in enumerate(train_dataloader):
-        if i > 3:
+        if i > 5:
             break
 
         data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 7382d879a..3f3149400 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -18,6 +18,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_sharded_params_padding
+from colossalai.amp import convert_to_apex_amp
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,8 +66,6 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
     model = model.cuda().float()
-    if dist.get_world_size() > 1:
-        model = DDP(model)
 
     if use_cpuadam:
         optimizer_class = CPUAdam
@@ -74,12 +73,16 @@
     sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
     sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
 
+    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
+    apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
+    if dist.get_world_size() > 1:
+        apex_model = DDP(apex_model)
+
     for i, (data, label) in enumerate(train_dataloader):
-        # FIXME() if i > 5, the unittest will fail
-        if i > 3:
+        if i > 5:
             break
         data, label = data.cuda(), label.cuda()
-        _run_step(model, optim, data, label, criterion, False)
+        _run_step(apex_model, apex_optimizer, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
         check_sharded_params_padding(model, zero_model, loose=True)
     for param in model.parameters():
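
Note: the core of the fixed test is a lockstep comparison. Two copies of the same model are
wrapped with ColossalAI's naive AMP and apex AMP backends under a matching static loss scale,
and their outputs, gradients, and updated parameters are compared with a loose tolerance after
each phase. The sketch below illustrates that pattern using only the calls visible in the patch
(convert_to_naive_amp, convert_to_apex_amp, assert_close_loose, and the wrapped optimizers'
backward()); SimpleNet, compare_naive_and_apex_amp, and the random input are hypothetical
stand-ins for the registry-provided test components, and a CUDA device plus NVIDIA apex are
assumed.

    # A minimal sketch of the lockstep AMP comparison pattern (hypothetical
    # SimpleNet and driver function; assumes CUDA and NVIDIA apex installed).
    import copy

    import torch
    import torch.nn as nn

    from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
    from colossalai.testing import assert_close_loose


    class SimpleNet(nn.Module):
        """Hypothetical stand-in for the registry-provided test models."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(32, 32)
            self.fc2 = nn.Linear(32, 32)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))


    def compare_naive_and_apex_amp():
        # two copies of the same weights, one per AMP backend
        naive_model = SimpleNet().cuda()
        apex_model = copy.deepcopy(naive_model)

        naive_optim = torch.optim.Adam(naive_model.parameters(), lr=1e-3)
        apex_optim = torch.optim.Adam(apex_model.parameters(), lr=1e-3)

        # matching static loss scales keep the backends step-for-step comparable
        naive_model, naive_optim = convert_to_naive_amp(naive_model, naive_optim,
                                                        dict(initial_scale=128))
        apex_model, apex_optim = convert_to_apex_amp(
            apex_model, apex_optim,
            dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False))

        data = torch.rand(4, 32).cuda()

        # forward: both backends should produce loosely matching outputs
        naive_out = naive_model(data)
        apex_out = apex_model(data)
        assert_close_loose(naive_out, apex_out)

        # backward goes through the wrapped optimizers, which apply loss scaling
        naive_optim.backward(naive_out.mean())
        apex_optim.backward(apex_out.mean())
        for p, q in zip(naive_model.parameters(), apex_model.parameters()):
            assert_close_loose(p.grad, q.grad)

        # step, then confirm the updated parameters still agree
        naive_optim.step()
        apex_optim.step()
        for p, q in zip(naive_model.parameters(), apex_model.parameters()):
            assert_close_loose(p, q)

The matching loss scale of 128 on both sides keeps the scaled gradients directly comparable,
and keep_batchnorm_fp32=False forces apex O2 to run batch-norm in fp16 as naive AMP does,
which appears to be what lets the newly added resnet18 model pass the loose comparison.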