mirror of https://github.com/InternLM/InternLM
unit tests for fused precision
parent 7af65f2088
commit cbb26d9136

import multiprocessing as mp
from functools import partial

import pytest
import torch
from torch import nn

from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.train.utils import create_param_groups
from tests.test_model.test_model_internlm import build_environment, seed_all


def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
    # the fp32-tagged submodule should receive inputs that were cast up to float32
    assert all(_.dtype == torch.float32 for _ in inputs)


def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
    # its outputs should be cast back to the AMP dtype (half) before flowing onward
    if isinstance(outputs, tuple):
        assert all(_.dtype == torch.half for _ in outputs)
    else:
        assert outputs.dtype == torch.half


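# NOTE (illustration, not part of the original test): the hooks above encode the
# contract the test expects from NaiveAMPModel -- a submodule tagged via
# set_fp32_attr_to_module should see float32 inputs while the surrounding model runs
# in half precision, and its outputs should be cast back down afterwards. A minimal
# sketch of how that behavior could be emulated with plain PyTorch module hooks
# (assumed mechanics only, not InternLM's actual implementation):
def _demo_wrap_fp32_submodule(module, amp_dtype=torch.half):
    def _cast_inputs_to_fp32(_module, inputs):
        # returning a tuple from a forward pre-hook replaces the module's inputs
        return tuple(x.float() if torch.is_tensor(x) else x for x in inputs)

    def _cast_outputs_to_amp(_module, _inputs, outputs):
        # returning a value from a forward hook replaces the module's output
        if torch.is_tensor(outputs):
            return outputs.to(amp_dtype)
        return tuple(x.to(amp_dtype) if torch.is_tensor(x) else x for x in outputs)

    module.float()  # keep this submodule's parameters in fp32
    module.register_forward_pre_hook(_cast_inputs_to_fp32)
    module.register_forward_hook(_cast_outputs_to_amp)
    return module

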
def check_fused_precision(args):
    """Check that a submodule marked fp32 inside a half-precision NaiveAMPModel
    receives fp32 inputs and returns half-precision outputs."""
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)

    # fix seed
    seed_all(1024)
    # define model
    model = PackedFlashBaseLayer1D(
        hidden_size=16,  # 768
        num_attention_heads=2,  # 12
        mlp_ratio=2,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        dtype=torch.bfloat16,
        layer_norm_epsilon=1e-5,
        checkpoint=False,
        layer_idx=0,
        residual_in_fp32=False,
        device=device,
        norm_type="rmsnorm",
        dropout_selective_checkpoint=True,
        use_scaled_init=True,
        use_swiglu=True,
    )
    model = model.to(device)
    set_fp32_attr_to_module(model.norm1)
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.half,
        sync_buffer=False,
    )
    model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check))
    model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check))

    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()

    # forward
    model(hidden_states)


# MlpModel mixes parameter dtypes on purpose: linear1 stays in fp16 while linear2
# is fp32, so create_param_groups has to split them into separate optimizer groups.
class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()
        self.linear2 = nn.Linear(1, 4, bias=False).float()


def check_split_fused_group(args):
    """Check that create_param_groups splits fp16 and fp32 parameters into
    the expected optimizer param groups."""
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)
    # define model
    model = MlpModel().to(device)
    groups = create_param_groups(model, weight_decay=0.05)

    standard_group = (
        {
            "name": "default",
            "params": [torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
        {
            "name": "fp32",
            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
    )

    # check groups params
    for t1, t2 in zip(groups, standard_group):
        # assert t1["name"] == t2["name"]
        assert all(
            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
        )


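# NOTE (illustration, not part of the original test): the expected result above is one
# "default" group carrying the fp16 weight of linear1 and one "fp32" group carrying the
# fp32 weight of linear2. A minimal sketch of splitting parameters by dtype into such
# groups (assumed behavior only, not the actual create_param_groups implementation):
def _demo_split_param_groups_by_dtype(model, weight_decay):
    fp32_params = [p for p in model.parameters() if p.dtype == torch.float32]
    other_params = [p for p in model.parameters() if p.dtype != torch.float32]
    return (
        {"name": "default", "params": other_params, "weight_decay": weight_decay},
        {"name": "fp32", "params": fp32_params, "weight_decay": weight_decay},
    )

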
@pytest.mark.fused_precision
def test_fused_precision():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_fused_precision, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.split_groups
def test_split_fused_groups():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_split_fused_group, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


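# NOTE (usage): both tests carry custom pytest markers, so they can be selected
# individually, e.g. `pytest -m fused_precision` or `pytest -m split_groups`
# (assuming the markers are registered in the project's pytest configuration to
# avoid unknown-marker warnings).

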
if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_norm.py"])