From ab513e1ddd85347d4eece53a8acff3228453e309 Mon Sep 17 00:00:00 2001
From: jiaxingli <43110891+li126com@users.noreply.github.com>
Date: Fri, 15 Sep 2023 18:56:56 +0800
Subject: [PATCH] feat: add optimizer_unittest (#303)

* feat: add optimizer_unittest

* feat: add optimizer test

* feat: add optimizer test

* feat: add optimizer test

* final change

* feat: add optimizer test

* feat: add optimizer test

* feat: add optimizer test
---
 tests/test_solver/test_optimizer.py | 364 ++++++++++++++++++++++++++++
 1 file changed, 364 insertions(+)
 create mode 100644 tests/test_solver/test_optimizer.py

diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py
new file mode 100644
index 0000000..6a22797
--- /dev/null
+++ b/tests/test_solver/test_optimizer.py
@@ -0,0 +1,364 @@
+import copy
+import multiprocessing as mp
+import random
+
+import numpy as np
+import pytest
+import torch
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import internlm
+from internlm.core.context.parallel_context import Config
+from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer.utils import ParamBcastSyncHandler
+
+
+class MlpModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.linear2 = nn.Linear(256, 512)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+config = Config(
+    dict(
+        parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
+        model_type="INTERNLM",
+        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
+        model=dict(
+            dtype=torch.bfloat16,
+        ),
+        resume_tb_folder="",
+        tensorboard_folder="",
+        alert_address=None,
+        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
+        grad_scaler=dict(
+            fp16=dict(
+                initial_scale=1,
+                min_scale=1,
+                growth_interval=1,
+            ),
+            growth_factor=1.1,
+            backoff_factor=0.9,
+            max_scale=1,
+            hysteresis=1,
+        ),
+        adam=dict(
+            lr=1e-4,
+            adam_beta1=0.9,
+            adam_beta2=0.95,
+            adam_beta2_c=0,
+            adam_eps=1e-8,
+            weight_decay=0.01,
+        ),
+        hybrid_zero_optimizer=dict(
+            overlap_sync_grad=False,
+            overlap_sync_param=False,
+            reduce_bucket_size=512 * 1024 * 1024,
+            clip_grad_norm=1.0,
+        ),
+    )
+)
+
+
+def build_environment(rank, world_size):
+    import os
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "12345"
+    torch.cuda.empty_cache()
+    # launcher="torch"
+    internlm.launch_from_torch(config=config, seed=1024)
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+
+    if dtype is torch.float32:
+        rtol = 1.3e-6
+        atol = 1e-5
+    elif dtype is torch.bfloat16:
+        rtol = 2e-2
+        atol = 2e-2
+
+    if isinstance(a, torch.Tensor):
+        a = a.detach().to(dtype)
+        b = b.detach().to(dtype)
+
+    assert_close(a, b, rtol=rtol, atol=atol)
+
+
+def init_optimizer_grouped_parameters(check_group, model):
+    if check_group:
+        optimizer_grouped_parameters = [
+            {
+                "params": list(model.parameters())[:2],
+                "weight_decay": config.adam.weight_decay,
+            },
+            {
+                "params": list(model.parameters())[2:],
+                "weight_decay": config.adam.weight_decay,
+            },
+        ]
+    else:
+        optimizer_grouped_parameters = [{"params": model.parameters(), "weight_decay": config.adam.weight_decay}]
+
+    return optimizer_grouped_parameters
+
+
+def seed_all(seed, cuda_deterministic=False):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    if cuda_deterministic:  # slower, more reproducible
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+    else:
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
+def exam_hybrid_zero_optim_with_ddp(args):
+    # init
+    rank, world_size, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype = args
+    # TODO: test the combination of overlap_sync_param and grouped params once it is supported;
+    # ParamBcastSyncHandler currently does not handle parameters split across optimizer groups.
+    if overlap_sync_param and check_group:
+        return
+    config.parallel.zero1 = zero_parallel
+    config.hybrid_zero_optimizer.overlap_sync_param = overlap_sync_param
+    config.hybrid_zero_optimizer.overlap_sync_grad = overlap_sync_grad
+    config.data.micro_num = micro_num
+    config.model.dtype = dtype
+    total_step = 5
+    if not overlap_sync_param:
+        total_step = 1
+
+    build_environment(rank, world_size)
+    seed_all(1024)
+
+    # create models
+    torch_model = MlpModel().cuda()
+    zero_model = copy.deepcopy(torch_model).to(dtype)
+    torch_model = DDP(torch_model, static_graph=True).cuda()
+
+    # create optimizer
+    if config.hybrid_zero_optimizer.overlap_sync_param:
+        param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
+    else:
+        param_bcast_sync_handler = None
+
+    optimizer_grouped_parameters_zero = init_optimizer_grouped_parameters(check_group, zero_model)
+    optimizer_grouped_parameters_torch = init_optimizer_grouped_parameters(check_group, torch_model)
+
+    naive_optimizer = torch.optim.AdamW(
+        params=optimizer_grouped_parameters_zero,
+        lr=config.adam.lr,
+        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
+        eps=config.adam.adam_eps,
+    )
+
+    zero_optimizer = HybridZeroOptimizer(
+        naive_optimizer,
+        grad_scal_cfg=config.grad_scaler,
+        zero_cfg=config.hybrid_zero_optimizer,
+        param_bcast_sync_handler=param_bcast_sync_handler,
+    )
+
+    torch_optimizer = torch.optim.AdamW(
+        params=optimizer_grouped_parameters_torch,
+        lr=config.adam.lr,
+        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
+        eps=config.adam.adam_eps,
+    )
+
+    for _ in range(total_step):
+        zero_optimizer.zero_grad()
+        torch_optimizer.zero_grad()
+        zero_optimizer.skip_grad_reduce = True
+        for num in range(micro_num):
+            if num == micro_num - 1:
+                zero_optimizer.skip_grad_reduce = False
+
+            seed_all(1024 + rank)
+            # create input
+            input_data = torch.rand(16, 128).cuda()
+
+            # zero-dp forward
+            zero_output = zero_model(input_data.to(dtype))
+
+            # torch-ddp forward
+            torch_output = torch_model(input_data)
+
+            # check output
+            loose_close(zero_output, torch_output, dtype=dtype)
+
+            # zero-dp backward
+            zero_optimizer.backward(zero_output.mean())
+
+            # torch-ddp backward
+            if num == micro_num - 1:
+                torch_output.mean().backward()
+            else:
+                with torch_model.no_sync():
+                    torch_output.mean().backward()
+
+        # zero-dp step
+        zero_optimizer.step()
+
+        # torch-ddp step
+        torch_optimizer.step()
+
+        # check grad
+        if check_group:
+            group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
+            group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
+            for torch_param, zero_param in group1:
+                if zero_param.grad is not None:
+                    loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
+            for torch_param, zero_param in group2:
+                if zero_param.grad is not None:
+                    loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
+        else:
+            for torch_param, zero_param in zip(torch_model.parameters(), zero_model.parameters()):
+                if zero_param.grad is not None:
+                    loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
+
+        torch.cuda.synchronize()
+        # check updated param
+        if check_group:
+            group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
+            group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
+            for torch_param, zero_param in group1:
+                loose_close(torch_param, zero_param, dtype=dtype)
+            for torch_param, zero_param in group2:
+                loose_close(torch_param, zero_param, dtype=dtype)
+        else:
+            for torch_param, zero_param in zip(torch_model.parameters(), zero_model.parameters()):
+                loose_close(torch_param, zero_param, dtype=dtype)
+
+
+def exam_hybrid_zero_optim_with_ckpt_load_save(args):
+    # init
+    rank, world_size, zero_parallel, check_group, dtype = args
+    config.parallel.zero1 = zero_parallel
+    config.model.dtype = dtype
+
+    build_environment(rank, world_size)
+
+    # create models
+    zero_model = MlpModel().cuda().to(dtype)
+
+    # create optimizer
+    if config.hybrid_zero_optimizer.overlap_sync_param:
+        param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
+    else:
+        param_bcast_sync_handler = None
+
+    optimizer_grouped_parameters1 = init_optimizer_grouped_parameters(check_group, zero_model)
+    optimizer_grouped_parameters2 = init_optimizer_grouped_parameters(check_group, zero_model)
+
+    naive_optimizer = torch.optim.AdamW(
+        params=optimizer_grouped_parameters1,
+        lr=config.adam.lr,
+        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
+        eps=config.adam.adam_eps,
+    )
+
+    zero_optimizer = HybridZeroOptimizer(
+        naive_optimizer,
+        grad_scal_cfg=config.grad_scaler,
+        zero_cfg=config.hybrid_zero_optimizer,
+        param_bcast_sync_handler=param_bcast_sync_handler,
+    )
+
+    naive_optimizer2 = torch.optim.AdamW(
+        params=optimizer_grouped_parameters2,
+        lr=config.adam.lr,
+        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
+        eps=config.adam.adam_eps,
+    )
+
+    zero_optimizer2 = HybridZeroOptimizer(
+        naive_optimizer2,
+        grad_scal_cfg=config.grad_scaler,
+        zero_cfg=config.hybrid_zero_optimizer,
+        param_bcast_sync_handler=param_bcast_sync_handler,
+    )
+
+    # save and load states
+    states = zero_optimizer.state_dict()
+    zero_optimizer2.load_state_dict(states)
+
+    # check fp32 model weights
+    for zero1_param, zero2_param in zip(
+        zero_optimizer._fp32_flat_param_groups_of_current_rank.values(),
+        zero_optimizer2._fp32_flat_param_groups_of_current_rank.values(),
+    ):
+        assert torch.equal(zero1_param, zero2_param)
+
+    # check fp16 model weights
+    for zero1_param, zero2_param in zip(
+        zero_optimizer._fp16_param_groups.values(), zero_optimizer2._fp16_param_groups.values()
+    ):
+        assert zero1_param == zero2_param
+
+
+zero_parallel_check_list = [-1, 1, 4]
+overlap_sync_param_check_list = [True, False]
+overlap_sync_grad_check_list = [True, False]
+micro_num_check_list = [1, 2, 4]
+check_group_list = [True, False]
+dtype_list = [torch.float32, torch.bfloat16]
+
+
+@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
+@pytest.mark.parametrize("overlap_sync_param", overlap_sync_param_check_list)
+@pytest.mark.parametrize("overlap_sync_grad", overlap_sync_grad_check_list)
+@pytest.mark.parametrize("micro_num", micro_num_check_list)
+@pytest.mark.parametrize("check_group", check_group_list)
+@pytest.mark.parametrize("dtype", dtype_list)
+def test_hybrid_zero_optim_with_ddp(
+    zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype
+):
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        pool.map(
+            exam_hybrid_zero_optim_with_ddp,
+            [
+                [rank, 8, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype]
+                for rank in range(8)
+            ],
+        )
+        pool.close()
+        pool.join()
+
+
+@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
+@pytest.mark.parametrize("check_group", check_group_list)
+@pytest.mark.parametrize("dtype", dtype_list)
+def test_hybrid_zero_optim_with_ckpt_load_save(zero_parallel, check_group, dtype):
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        pool.map(
+            exam_hybrid_zero_optim_with_ckpt_load_save,
+            [[rank, 8, zero_parallel, check_group, dtype] for rank in range(8)],
+        )
+        pool.close()
+        pool.join()
+
+
+if __name__ == "__main__":
+    pytest.main(["-s", "-q", "test_optimizer.py"])
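
Note on running this suite: each test spawns 8 worker processes, and every worker sets RANK/LOCAL_RANK/WORLD_SIZE
for an 8-process job and initializes distributed state through internlm.launch_from_torch with MASTER_PORT=12345,
so a single node with 8 GPUs and that port free is assumed. A minimal sketch for invoking only this module
(assuming it is run from the repository root, so the path added by this patch resolves) is:

    import pytest

    # Run only the optimizer tests added by this patch; -s echoes stdout from the spawned workers.
    pytest.main(["-s", "-q", "tests/test_solver/test_optimizer.py"])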