unit tests for fused precision

pull/319/head
Wenwen Qu 2023-09-25 14:50:54 +08:00
parent 7af65f2088
commit cbb26d9136
1 changed file with 128 additions and 0 deletions


@@ -0,0 +1,128 @@
import multiprocessing as mp

import pytest
import torch
from torch import nn

from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.train.utils import create_param_groups
from tests.test_model.test_model_internlm import build_environment, seed_all


def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
    # norm1 is marked as an fp32 module, so its inputs should arrive in float32
    assert all(_.dtype == torch.float32 for _ in inputs)


def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
    # outside the fp32 module, activations should be back in the AMP dtype (fp16)
    if isinstance(outputs, tuple):
        assert all(_.dtype == torch.half for _ in outputs)
    else:
        assert outputs.dtype == torch.half


def check_fused_precision(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)

    # fix seed
    seed_all(1024)

    # define model
    model = PackedFlashBaseLayer1D(
        hidden_size=16,  # 768
        num_attention_heads=2,  # 12
        mlp_ratio=2,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        dtype=torch.bfloat16,
        layer_norm_epsilon=1e-5,
        checkpoint=False,
        layer_idx=0,
        residual_in_fp32=False,
        device=device,
        norm_type="rmsnorm",
        dropout_selective_checkpoint=True,
        use_scaled_init=True,
        use_swiglu=True,
    )
    model = model.to(device)

    # keep norm1 in fp32 while the rest of the block runs in half precision
    set_fp32_attr_to_module(model.norm1)

    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.half,
        sync_buffer=False,
    )

    # the hooks assert the dtypes at the fp32-module boundary
    model.model.norm1.register_forward_pre_hook(_pre_forward_hook_for_check)
    model.model.norm1.register_forward_hook(_post_forward_hook_for_check)

    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()

    # forward
    model(hidden_states)
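

# Sketch of the dtype flow the hooks above assert: NaiveAMPModel casts inputs
# to torch.half, converts them back to float32 at the boundary of the
# fp32-marked norm1, and recasts norm1's outputs to half for the rest of the
# block.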


class MlpModel(nn.Module):
    """Toy model mixing an fp16 linear and an fp32 linear, so that
    create_param_groups has both kinds of parameters to split."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()
        self.linear2 = nn.Linear(1, 4, bias=False).float()


def check_split_fused_group(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)

    # define model
    model = MlpModel().to(device)
    groups = create_param_groups(model, weight_decay=0.05)

    # fp16 params are expected to stay in the default group, while fp32 params
    # are split out into a separate "fp32" group
    standard_group = (
        {
            "name": "default",
            "params": [
                torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()
            ],
            "weight_decay": 0.05,
        },
        {
            "name": "fp32",
            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
    )

    # check groups params
    for t1, t2 in zip(groups, standard_group):
        # assert t1["name"] == t2["name"]
        assert all(
            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
        )
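

# The hard-coded tensors in standard_group are the parameter values this model
# gets under seed_all(1024); they act as a fixed reference for the group split.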


@pytest.mark.fused_precision
def test_fused_precision():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_fused_precision, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.split_groups
def test_split_fused_groups():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_split_fused_group, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()
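

# NOTE: both tests spawn 8 worker processes, each building a distributed
# environment with world_size=8, so they assume a node with 8 visible GPUs.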


if __name__ == "__main__":
    pytest.main(["-s", "-q", __file__])