# InternLM/tests/test_solver/test_optimizer.py

import copy
import multiprocessing as mp
import random

import numpy as np
import pytest
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import internlm
from internlm.core.context.parallel_context import Config, ParallelMode
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x
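
# Illustrative usage sketch (not executed by the tests): the toy model maps a
# (batch, 128) input to a (batch, 512) output, matching the torch.rand(16, 128)
# inputs fed to it in the DDP test below.
#     model = MlpModel().cuda()
#     out = model(torch.rand(16, 128).cuda())  # out.shape == (16, 512)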


config = Config(
    dict(
        parallel=dict(
            zero1=dict(size=1, fsdp=False),
            pipeline=dict(size=1, interleaved_overlap=False),
            sequence_parallel=False,
            tensor=1,
        ),
        model_type="INTERNLM",
        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
        model=dict(
            dtype=torch.bfloat16,
        ),
        resume_tb_folder="",
        tensorboard_folder="",
        alert_address=None,
        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
        grad_scaler=dict(
            fp16=dict(
                initial_scale=1,
                min_scale=1,
                growth_interval=1,
            ),
            growth_factor=1.1,
            backoff_factor=0.9,
            max_scale=1,
            hysteresis=1,
        ),
        adam=dict(
            lr=1e-4,
            adam_beta1=0.9,
            adam_beta2=0.95,
            adam_beta2_c=0,
            adam_eps=1e-8,
            weight_decay=0.01,
        ),
        hybrid_zero_optimizer=dict(
            overlap_sync_grad=False,
            overlap_sync_param=False,
            reduce_bucket_size=512 * 1024 * 1024,
            clip_grad_norm=1.0,
        ),
    )
)


def build_environment(rank, world_size):
    import os

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12345"
    torch.cuda.empty_cache()
    # launcher="torch"
    internlm.launch_from_torch(config=config, seed=1024)


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    if dtype is torch.float32:
        rtol = 1.3e-6
        atol = 1e-5
    elif dtype is torch.bfloat16:
        rtol = 2e-2
        atol = 2e-2
    else:
        raise ValueError(f"unsupported dtype for loose_close: {dtype}")

    if isinstance(a, torch.Tensor):
        a = a.detach().to(dtype)
        b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)
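
# Usage sketch: loose_close(zero_out, torch_out, dtype=torch.bfloat16) compares
# with the relaxed bf16 tolerances above; tensors are detached and cast to
# `dtype` first, anything else is passed straight to assert_close.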


def init_optimizer_grouped_parameters(check_group, model):
    if check_group:
        optimizer_grouped_parameters = [
            {
                "params": list(model.parameters())[:2],
                "weight_decay": config.adam.weight_decay,
                "dp_mode": ParallelMode.DATA,
            },
            {
                "params": list(model.parameters())[2:],
                "weight_decay": config.adam.weight_decay,
                "dp_mode": ParallelMode.DATA,
            },
        ]
    else:
        optimizer_grouped_parameters = [
            {
                "params": model.parameters(),
                "weight_decay": config.adam.weight_decay,
                "dp_mode": ParallelMode.DATA,
            }
        ]
    return optimizer_grouped_parameters
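
# Note: for MlpModel, list(model.parameters())[:2] holds linear1.weight and
# linear1.bias, and [2:] holds linear2.weight and linear2.bias, so
# check_group=True splits the two linear layers into separate param groups.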


def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def exam_hybrid_zero_optim_with_ddp(args):
    # init
    rank, world_size, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype = args

    # TODO: test the combination of overlap_sync_param and grouped params when ready;
    # ParamBcastSyncHandler does not currently consider parameters in different optimizer groups
    if overlap_sync_param and check_group:
        return

    config.parallel.zero1.size = zero_parallel
    config.hybrid_zero_optimizer.overlap_sync_param = overlap_sync_param
    config.hybrid_zero_optimizer.overlap_sync_grad = overlap_sync_grad
    config.data.micro_num = micro_num
    config.model.dtype = dtype

    total_step = 5
    if not overlap_sync_param:
        total_step = 1

    build_environment(rank, world_size)
    seed_all(1024)

    # create models
    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model).to(dtype)
    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    if config.hybrid_zero_optimizer.overlap_sync_param:
        param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
    else:
        param_bcast_sync_handler = None

    optimizer_grouped_parameters_zero = init_optimizer_grouped_parameters(check_group, zero_model)
    optimizer_grouped_parameters_torch = init_optimizer_grouped_parameters(check_group, torch_model)

    naive_optimizer = torch.optim.AdamW(
        params=optimizer_grouped_parameters_zero,
        lr=config.adam.lr,
        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
        eps=config.adam.adam_eps,
    )
    zero_optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=config.grad_scaler,
        zero_cfg=config.hybrid_zero_optimizer,
        param_bcast_sync_handler=param_bcast_sync_handler,
    )
    torch_optimizer = torch.optim.AdamW(
        params=optimizer_grouped_parameters_torch,
        lr=config.adam.lr,
        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
        eps=config.adam.adam_eps,
    )

    for _ in range(total_step):
        zero_optimizer.zero_grad()
        torch_optimizer.zero_grad()
        zero_optimizer.skip_grad_reduce = True

        for num in range(micro_num):
            if num == micro_num - 1:
                zero_optimizer.skip_grad_reduce = False

            seed_all(1024 + rank)
            # create input
            input_data = torch.rand(16, 128).cuda()

            # zero-dp forward
            zero_output = zero_model(input_data.to(dtype))
            # torch-ddp forward
            torch_output = torch_model(input_data)

            # check output
            loose_close(zero_output, torch_output, dtype=dtype)

            # zero-dp backward
            zero_optimizer.backward(zero_output.mean())
            # torch-ddp backward
            if num == micro_num - 1:
                torch_output.mean().backward()
            else:
                with torch_model.no_sync():
                    torch_output.mean().backward()

        # zero-dp step
        zero_optimizer.step()
        # torch-ddp step
        torch_optimizer.step()

        # check grad
        if check_group:
            group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
            group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
            for torch_param, zero_param in group1:
                if zero_param.grad is not None:
                    loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
            for torch_param, zero_param in group2:
                if zero_param.grad is not None:
                    loose_close(torch_param.grad, zero_param.grad, dtype=dtype)
        else:
            for torch_param, zero_param in zip(torch_model.parameters(), zero_model.parameters()):
                if zero_param.grad is not None:
                    loose_close(torch_param.grad, zero_param.grad, dtype=dtype)

        torch.cuda.synchronize()

        # check updated param
        if check_group:
            group1 = zip(list(torch_model.parameters())[:2], list(zero_model.parameters())[:2])
            group2 = zip(list(torch_model.parameters())[2:], list(zero_model.parameters())[2:])
            for torch_param, zero_param in group1:
                loose_close(torch_param, zero_param, dtype=dtype)
            for torch_param, zero_param in group2:
                loose_close(torch_param, zero_param, dtype=dtype)
        else:
            for torch_param, zero_param in zip(torch_model.parameters(), zero_model.parameters()):
                loose_close(torch_param, zero_param, dtype=dtype)


def exam_hybrid_zero_optim_with_ckpt_load_save(args):
    # init
    rank, world_size, zero_parallel, check_group, dtype = args

    config.parallel.zero1.size = zero_parallel
    config.model.dtype = dtype

    build_environment(rank, world_size)

    # create models
    zero_model = MlpModel().cuda().to(dtype)

    # create optimizer
    if config.hybrid_zero_optimizer.overlap_sync_param:
        param_bcast_sync_handler = ParamBcastSyncHandler(zero_model)
    else:
        param_bcast_sync_handler = None

    optimizer_grouped_parameters1 = init_optimizer_grouped_parameters(check_group, zero_model)
    optimizer_grouped_parameters2 = init_optimizer_grouped_parameters(check_group, zero_model)

    naive_optimizer = torch.optim.AdamW(
        params=optimizer_grouped_parameters1,
        lr=config.adam.lr,
        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
        eps=config.adam.adam_eps,
    )
    zero_optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=config.grad_scaler,
        zero_cfg=config.hybrid_zero_optimizer,
        param_bcast_sync_handler=param_bcast_sync_handler,
    )
    naive_optimizer2 = torch.optim.AdamW(
        params=optimizer_grouped_parameters2,
        lr=config.adam.lr,
        betas=(config.adam.adam_beta1, config.adam.adam_beta2),
        eps=config.adam.adam_eps,
    )
    zero_optimizer2 = HybridZeroOptimizer(
        naive_optimizer2,
        grad_scal_cfg=config.grad_scaler,
        zero_cfg=config.hybrid_zero_optimizer,
        param_bcast_sync_handler=param_bcast_sync_handler,
    )

    # save and load states
    states = zero_optimizer.state_dict()
    zero_optimizer2.load_state_dict(states)

    # check fp32 master weights
    for zero1_param, zero2_param in zip(
        zero_optimizer._fp32_flat_param_groups_of_current_rank.values(),
        zero_optimizer2._fp32_flat_param_groups_of_current_rank.values(),
    ):
        assert torch.equal(zero1_param, zero2_param)

    # check fp16 model weights
    for zero1_param, zero2_param in zip(
        zero_optimizer._fp16_param_groups.values(), zero_optimizer2._fp16_param_groups.values()
    ):
        assert zero1_param == zero2_param


zero_parallel_check_list = [-1, 1, 4]
overlap_sync_param_check_list = [True, False]
overlap_sync_grad_check_list = [True, False]
micro_num_check_list = [1, 2, 4]
check_group_list = [True, False]
dtype_list = [torch.float32, torch.bfloat16]
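
# The parametrize grids below expand to 3 * 2 * 2 * 3 * 2 * 2 = 144 cases for the
# DDP test and 3 * 2 * 2 = 12 cases for the checkpoint save/load test; each case
# spawns 8 worker processes via the multiprocessing "spawn" context.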
@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
@pytest.mark.parametrize("overlap_sync_param", overlap_sync_param_check_list)
@pytest.mark.parametrize("overlap_sync_grad", overlap_sync_grad_check_list)
@pytest.mark.parametrize("micro_num", miro_num_check_list)
@pytest.mark.parametrize("check_group", check_group_list)
@pytest.mark.parametrize("dtype", dtype_list)
def test_hybrid_zero_optim_with_ddp(
zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype
):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(
exam_hybrid_zero_optim_with_ddp,
[
[rank, 8, zero_parallel, overlap_sync_param, overlap_sync_grad, micro_num, check_group, dtype]
for rank in range(8)
],
)
pool.close()
pool.join()
@pytest.mark.parametrize("zero_parallel", zero_parallel_check_list)
@pytest.mark.parametrize("check_group", check_group_list)
@pytest.mark.parametrize("dtype", dtype_list)
def test_hybrid_zero_optim_with_ckpt_load_save(zero_parallel, check_group, dtype):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(
exam_hybrid_zero_optim_with_ckpt_load_save,
[[rank, 8, zero_parallel, check_group, dtype] for rank in range(8)],
)
pool.close()
pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_optimizer.py"])
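
# Note: each test case launches 8 ranks (world_size=8) that all set up CUDA and
# call internlm.launch_from_torch, so running this file directly presumably
# requires a node with 8 GPUs.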