feat: add optimizer unit test (#303)

* feat: add optimizer unit test

* feat: add optimizer test

* final change

jiaxingli 2023-09-15 18:56:56 +08:00 committed by GitHub
parent 794a484666
commit ab513e1ddd
1 changed file with 364 additions and 0 deletions

@@ -0,0 +1,364 @@
import copy
import multiprocessing as mp
import random
import numpy as np
import pytest
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import internlm
from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
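# Toy two-layer MLP (128 -> 256 -> 512) used as the model under test.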
class MlpModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
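# Base InternLM training config shared by all cases; zero1 size, overlap flags,
# micro_num and dtype are overridden inside the individual test workers.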
config = Config(
dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
dtype=torch.bfloat16,
),
resume_tb_folder="",
tensorboard_folder="",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
grad_scaler=dict(
fp16=dict(
initial_scale=1,
min_scale=1,
growth_interval=1,
),
growth_factor=1.1,
backoff_factor=0.9,
max_scale=1,
hysteresis=1,
),
adam=dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
),
hybrid_zero_optimizer=dict(
overlap_sync_grad=False,
overlap_sync_param=False,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
),
)
)
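# Populate the torch.distributed environment variables for this rank and
# initialize InternLM's parallel context via launch_from_torch.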
def build_environment(rank, world_size):
import os
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12345"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
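# Tensor comparison with dtype-dependent tolerances (loose for bfloat16).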
def loose_close(a, b, dtype: torch.dtype = torch.float32):
if dtype is torch.float32:
rtol = 1.3e-6
atol = 1e-5
elif dtype is torch.bfloat16:
rtol = 2e-2
atol = 2e-2
else:
raise ValueError(f"unsupported dtype for loose_close: {dtype}")
if isinstance(a, torch.Tensor):
a = a.detach().to(dtype)
b = b.detach().to(dtype)
assert_close(a, b, rtol=rtol, atol=atol)
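# Build either a single parameter group or two groups (the first two parameters
# vs. the rest) so the grouped-parameter code path is exercised as well.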
def init_optimizer_grouped_parameters(check_group, model):
if check_group:
optimizer_grouped_parameters = [
{
"params": list(model.parameters())[:2],
"weight_decay": config.adam.weight_decay,
},
{
"params": list(model.parameters())[2:],
"weight_decay": config.adam.weight_decay,
},
]
else:
optimizer_grouped_parameters = [{"params": model.parameters(), "weight_decay": config.adam.weight_decay}]
return optimizer_grouped_parameters
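# Seed Python, NumPy and torch RNGs; optionally force deterministic cuDNN kernels.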
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
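# Worker (one per rank): train identical MLPs with HybridZeroOptimizer and with
# torch DDP + AdamW, then compare forward outputs, gradients and updated parameters.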
def exam_hybrid_zero_optim_with_ddp(args):
# init
rank, world_size, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype = args
# TODO: test the combination of overlap_sync_param and grouped params once supported;
# ParamBcastSyncHandler currently does not handle parameters split across optimizer groups.
if overlap_sync_param and check_group:
return
config.parallel.zero1 = zero_parallel
config.hybrid_zero_optimizer.overlap_sync_param = overlap_sync_param
config.hybrid_zero_optimizer.overlap_sync_grad = overlap_sync_grad
config.data.micro_num = micro_num
config.model.dtype = dtype
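# Run several optimizer steps when parameter-broadcast overlap is enabled, otherwise
# just one step (presumably to keep the non-overlap cases fast).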
total_step = 5
if not overlap_sync_param:
total_step = 1
build_environment(rank, world_size)
seed_all(1024)
# create models
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model).to(dtype)
torch_model = DDP(torch_model, static_graph=True)
# create optimizer
if config.hybrid_zero_optimizer.overlap_sync_param:
param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
else:
param_bcast_sync_handler = None
optimizer_grouped_parameters_zero = init_optimizer_grouped_parameters(check_group, zero_model)
optimizer_grouped_parameters_torch = init_optimizer_grouped_parameters(check_group, torch_model)
naive_optimizer = torch.optim.AdamW(
params=optimizer_grouped_parameters_zero,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
zero_optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=config.grad_scaler,
zero_cfg=config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
torch_optimizer = torch.optim.AdamW(
params=optimizer_grouped_parameters_torch,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
for _ in range(total_step):
zero_optimizer.zero_grad()
torch_optimizer.zero_grad()
zero_optimizer.skip_grad_reduce = True
for num in range(micro_num):
if num == micro_num - 1:
zero_optimizer.skip_grad_reduce = False
seed_all(1024 + rank)
# create input
input_data = torch.rand(16, 128).cuda()
# zero-dp forward
zero_output = zero_model(input_data.to(dtype))
# torch-ddp forward
torch_output = torch_model(input_data)
# check output
loose_close(zero_output, torch_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean())
# torch-ddp backward
if num == micro_num - 1:
torch_output.mean().backward()
else:
with torch_model.no_sync():
torch_output.mean().backward()
# zero-dp step
zero_optimizer.step()
# torch-ddp step
torch_optimizer.step()
# check grad
if check_group:
group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
for torch_param, zero_param in group1:
if zero_param.grad is not None:
loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
for torch_param, zero_param in group2:
if zero_param.grad is not None:
loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
else:
for torch_param, zero_param in zip(torch_model.parameters(), zero_model.parameters()):
if zero_param.grad is not None:
loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
torch.cuda.synchronize()
# check updated param
if check_group:
group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
for torch_param, zero_param in group1:
loose_close(torch_param, zero_param, dtype=dtype)
for torch_param, zero_param in group2:
loose_close(torch_param, zero_param, dtype=dtype)
else:
for torch_param, zero_param in zip(torch_model.parameters(), zero_model.parameters()):
loose_close(torch_param, zero_param, dtype=dtype)
def exam_hybrid_zero_optim_with_ckpt_load_save(args):
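# Worker (one per rank): verify that HybridZeroOptimizer state survives a
# state_dict() / load_state_dict() round trip.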
# init
rank, world_size, zero_parallel, check_group, dtype = args
config.parallel.zero1 = zero_parallel
config.model.dtype = dtype
build_environment(rank, world_size)
# create models
zero_model = MlpModel().cuda().to(dtype)
# create optimizer
if config.hybrid_zero_optimizer.overlap_sync_param:
param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
else:
param_bcast_sync_handler = None
optimizer_grouped_parameters1 = init_optimizer_grouped_parameters(check_group, zero_model)
optimizer_grouped_parameters2 = init_optimizer_grouped_parameters(check_group, zero_model)
naive_optimizer = torch.optim.AdamW(
params=optimizer_grouped_parameters1,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
zero_optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=config.grad_scaler,
zero_cfg=config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
naive_optimizer2 = torch.optim.AdamW(
params=optimizer_grouped_parameters2,
lr=config.adam.lr,
betas=(config.adam.adam_beta1, config.adam.adam_beta2),
eps=config.adam.adam_eps,
)
zero_optimizer2 = HybridZeroOptimizer(
naive_optimizer2,
grad_scal_cfg=config.grad_scaler,
zero_cfg=config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
# save and load states
states = zero_optimizer.state_dict()
zero_optimizer2.load_state_dict(states)
# check fp32 model weights
for zero1_param, zero2_param in zip(
zero_optimizer._fp32_flat_param_groups_of_current_rank.values(),
zero_optimizer2._fp32_flat_param_groups_of_current_rank.values(),
):
assert torch.equal(zero1_param, zero2_param)
# check fp16 model weights
for zero1_param, zero2_param in zip(
zero_optimizer._fp16_param_groups.values(), zero_optimizer2._fp16_param_groups.values()
):
assert zero1_param == zero2_param
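# Parameter grids for the pytest parametrizations below.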
zero_parallel_check_list = [-1, 1, 4]
overlap_sync_param_check_list = [True, False]
overlap_sync_grad_check_list = [True, False]
micro_num_check_list = [1, 2, 4]
check_group_list = [True, False]
dtype_list = [torch.float32, torch.bfloat16]
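# Each test spawns 8 worker processes (world_size = 8), so running the suite assumes
# a node with 8 GPUs available.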
@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
@pytest.mark.parametrize("overlap_sync_param", overlap_sync_param_check_list)
@pytest.mark.parametrize("overlap_sync_grad", overlap_sync_grad_check_list)
@pytest.mark.parametrize("micro_num", miro_num_check_list)
@pytest.mark.parametrize("check_group", check_group_list)
@pytest.mark.parametrize("dtype", dtype_list)
def test_hybrid_zero_optim_with_ddp(
zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype
):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(
exam_hybrid_zero_optim_with_ddp,
[
[rank, 8, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype]
for rank in range(8)
],
)
pool.close()
pool.join()
@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
@pytest.mark.parametrize("check_group", check_group_list)
@pytest.mark.parametrize("dtype", dtype_list)
def test_hybrid_zero_optim_with_ckpt_load_save(zero_parallel, check_group, dtype):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(
exam_hybrid_zero_optim_with_ckpt_load_save,
[[rank, 8, zero_parallel, check_group, dtype] for rank in range(8)],
)
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_optimizer.py"])