feat: add unitest for model (#300)

* feat: add unitest for model

* feat:add model test
pull/310/head
jiaxingli 2023-09-14 13:18:34 +08:00 committed by GitHub
parent 85e39aae67
commit 882a07011c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 528 additions and 0 deletions

View File

@ -0,0 +1,65 @@
import multiprocessing as mp
import pytest
import torch
from internlm.model.embedding import Embedding1D
from tests.test_model.test_model_internlm import build_environment, seed_all
def check_embedding(args):
# init
rank, world_size = args
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
vocab_size = 4
hidden_size = 2
# fix seed
seed_all(1024)
# define embedding
embedding = Embedding1D(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
padding_idx=None,
)
embedding.weight.data.copy_(torch.randn(vocab_size, hidden_size))
embedding = embedding.to(device)
# create input
input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
result = embedding(input_ids)
standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
standard_result = torch.tensor(standard_list).to(device)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol, equal_nan=True)
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = embedding.weight.grad
standard_glist = [[-0.4461, 0.5602], [0.4353, 1.2988], [-0.0625, -1.3609], [0.9595, -0.1144]]
standard_grad = torch.tensor(standard_glist).to(device)
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)
@pytest.mark.embedding
def test_embedding():
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_embedding, [[rank, 8] for rank in range(8)])
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_embedding.py"])

View File

@ -0,0 +1,379 @@
import multiprocessing as mp
import random
import numpy as np
import pytest
import torch
from torch import nn
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.linear import RewardModelLinear, ScaleColumnParallelLinear
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.model.utils import gather_forward_split_backward
config = Config(
dict(
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
model_type="INTERNLM",
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
model=dict(
checkpoint=False,
num_attention_heads=2,
embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=1024,
num_layers=2,
mlp_ratio=1,
apply_post_layer_norm=False,
dtype=torch.bfloat16,
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1,
),
resume_tb_folder="",
tensorboard_folder="",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
)
)
def build_environment(rank, world_size):
import os
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12345"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
def check_block(args):
# init
rank, world_size = args
build_environment(rank, world_size)
device = torch.device("cuda")
rtol, atol = (1e-3, 5e-3)
# fix seed
seed_all(1024)
# define block
blocks = nn.ModuleList(
[
PackedFlashBaseLayer1D(
hidden_size=4, # 768
num_attention_heads=2, # 12
mlp_ratio=2,
attn_drop_rate=0.0,
drop_rate=0.0,
dtype=torch.bfloat16,
layer_norm_epsilon=1e-5,
checkpoint=lid < 0,
layer_idx=lid + 0, # This parameter is used for caching during generation
residual_in_fp32=False,
device=device,
norm_type="rmsnorm",
dropout_selective_checkpoint=True,
use_scaled_init=True,
use_swiglu=True,
)
for lid in range(4) # 32
]
)
# create input
cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device) # [0, 8, 16]
indexes = torch.tensor([0, 1, 0, 1]).to(device) # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device) # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
hidden_states = torch.tensor(
[
[
[-1.1620, 1.3113, 0.1507, 2.2698],
[-1.2610, 1.0990, 0.3787, -0.3478],
[1.4001, 1.1982, -0.6696, 0.3269],
[1.3304, 1.2262, 1.0735, -1.1169],
]
]
)
hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()
# forward
for _, block in enumerate(blocks):
block = block.to(torch.bfloat16)
block = block.to(device)
hidden_states = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=None,
max_seqlen=max_seqlen,
)
result = hidden_states
standard_result = torch.tensor(
[
[-1.1621, 1.3111, 0.1509, 2.2697],
[-1.2611, 1.0988, 0.3787, -0.3478],
[1.4000, 1.1982, -0.6694, 0.3268],
[1.3303, 1.2262, 1.0736, -1.1169],
]
).to(device)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
hidden_states.retain_grad()
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = hidden_states.grad
standard_grad = torch.tensor(
[
[0.7999, -0.2595, 0.2649, -1.3256],
[0.7064, 0.0283, -0.5508, 0.6494],
[-1.4657, -2.0316, 1.3776, 0.7211],
[-0.6046, 0.4329, -0.1884, 1.1170],
]
).to(device)
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)
def check_head(args):
# init
rank, world_size, is_reward = args
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
hidden_size = 4
vocab_size = 4
embed_grad_scale = 1
# fix seed
seed_all(1024)
# load standard
if is_reward:
head_cls = RewardModelLinear
standard_result = torch.tensor([[3.5938], [1.0703], [3.6250], [3.6250]], dtype=torch.bfloat16).to(device)
standard_grad = torch.tensor(
[
[-0.2246, 0.0164, -0.0591, 0.1660],
[-0.5625, 0.0408, -0.1484, 0.4160],
[-0.1758, 0.0128, -0.0464, 0.1299],
[-0.4785, 0.0347, -0.1260, 0.3516],
],
dtype=torch.bfloat16,
).to(device)
else:
head_cls = ScaleColumnParallelLinear
standard_result = torch.tensor(
[
[3.5938, -2.2188, 2.0312, 3.5625],
[1.0703, -1.1797, 1.1406, 1.6641],
[3.6250, -2.0156, 1.7656, 3.4531],
[3.6250, -2.0156, 1.7656, 3.4531],
],
dtype=torch.bfloat16,
).to(device)
standard_grad = torch.tensor(
[
[-0.2354, 0.0981, -0.2930, -0.6328],
[0.2344, -0.2334, -0.0918, 0.1396],
[-0.5898, -1.0156, -0.7070, 1.3750],
[0.0242, -0.1494, 0.1206, -0.0427],
],
dtype=torch.bfloat16,
).to(device)
# define head
head = head_cls(
in_features=hidden_size,
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=torch.bfloat16,
weight_scale=embed_grad_scale,
)
head = head.to(torch.bfloat16)
head = head.to(device)
# create input
hidden_states = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
],
dtype=torch.bfloat16,
requires_grad=True,
).to(device)
# forward
result = head(hidden_states)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
hidden_states.retain_grad()
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = hidden_states.grad
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)
def check_gather_forward(args):
# init
rank, world_size, parallel_tensor = args
assert parallel_tensor in [1, 2]
config.parallel.tensor = parallel_tensor
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
# fix seed
seed_all(1024)
# load standard
if parallel_tensor == 1:
standard_result = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
]
).to(device)
standard_grad = torch.tensor(
[
[-0.4461, 0.5602, -0.0625, -1.3609],
[0.4353, 1.2988, 0.9595, -0.1144],
[-0.7593, -0.4031, 0.2041, 1.4955],
[0.5706, 0.9047, -0.6965, -0.3757],
]
).to(device)
else:
standard_result = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000, 8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000, 3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
]
).to(device)
if rank % 2 == 0:
standard_grad = torch.tensor(
[
[-0.4461, 0.5602, -0.0625, -1.3609],
[-0.7593, -0.4031, 0.2041, 1.4955],
[0.8093, 1.7580, 1.2996, -0.7545],
[1.0474, -0.5767, -1.0401, 0.8233],
]
).to(device)
else:
standard_grad = torch.tensor(
[
[0.4353, 1.2988, 0.9595, -0.1144],
[0.5706, 0.9047, -0.6965, -0.3757],
[-1.3589, -0.7202, 0.6094, -0.8208],
[-1.0042, 0.3695, 0.2511, -0.2718],
]
).to(device)
# create input
hidden_states = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
],
requires_grad=True,
).to(device)
# forward
result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
loss = torch.randn_like(result)
hidden_states.retain_grad()
# backward
result.backward(loss)
grad = hidden_states.grad
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)
@pytest.mark.block
def test_block():
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_block, [[rank, 8] for rank in range(8)])
pool.close()
pool.join()
@pytest.mark.head
@pytest.mark.parametrize("is_reward", [True, False])
def test_head(is_reward):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_head, [[rank, 8, is_reward] for rank in range(8)])
pool.close()
pool.join()
@pytest.mark.gather_forward
@pytest.mark.parametrize("parallel_tensor", [1, 2])
def test_gather_forward(parallel_tensor):
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_gather_forward, [[rank, 8, parallel_tensor] for rank in range(8)])
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_model_internlm.py"])

View File

@ -0,0 +1,84 @@
import multiprocessing as mp
import pytest
import torch
from internlm.model.utils import try_import_RMSNorm
from tests.test_model.test_model_internlm import build_environment, seed_all
RMSNorm = try_import_RMSNorm()
def check_norm(args):
# init
rank, world_size = args
device = torch.device("cuda")
build_environment(rank, world_size)
rtol, atol = (1e-3, 5e-3)
hidden_size = 4
layer_norm_epsilon = 1e-05
# fix seed
seed_all(1024)
# define norm
norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
norm = norm.to(device)
# create input
hidden_states = torch.tensor(
[
[8.3726, 1.9245, 5.5101, 1.0000],
[3.3474, 2.9582, 1.0000, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
[8.3726, 1.2875, 5.5101, 1.0000],
],
requires_grad=True,
).to(device)
# forward
result = norm(hidden_states.float())
standard = torch.tensor(
[
[1.6329, 0.3753, 1.0746, 0.1950],
[1.4288, 1.2626, 0.4268, 0.4268],
[1.6490, 0.2536, 1.0852, 0.1970],
[1.6490, 0.2536, 1.0852, 0.1970],
]
).to(device)
# check output
assert torch.allclose(result, standard, rtol=rtol, atol=atol, equal_nan=True)
hidden_states.retain_grad()
loss = torch.randn_like(result)
# backward
result.backward(loss)
grad = hidden_states.grad
standard_grad = torch.tensor(
[
[-0.0193, 0.1248, 0.0324, -0.2573],
[-0.2140, 0.2010, 0.2901, -0.1683],
[-0.0815, -0.0689, 0.0850, 0.3027],
[0.0847, 0.1739, -0.1554, -0.0773],
]
).to(device)
# check grad
assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)
@pytest.mark.norm
def test_norm():
ctx = mp.get_context("spawn")
with ctx.Pool(processes=8) as pool:
pool.map(check_norm, [[rank, 8] for rank in range(8)])
pool.close()
pool.join()
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_norm.py"])