# mirror of https://github.com/InternLM/InternLM
import multiprocessing as mp
import random

import numpy as np
import pytest
import torch
from torch import nn

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.linear import RewardModelLinear, ScaleColumnParallelLinear
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.model.utils import gather_forward_split_backward

config = Config(
    dict(
        parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
        model_type="INTERNLM",
        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
        model=dict(
            checkpoint=False,
            num_attention_heads=2,
            embed_split_hidden=True,
            vocab_size=103168,
            embed_grad_scale=1,
            parallel_output=True,
            hidden_size=1024,
            num_layers=2,
            mlp_ratio=1,
            apply_post_layer_norm=False,
            dtype=torch.bfloat16,
            norm_type="rmsnorm",
            layer_norm_epsilon=1e-5,
            use_flash_attn=True,
            num_chunks=1,
        ),
        resume_tb_folder="",
        tensorboard_folder="",
        alert_address=None,
        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
    )
)


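# build_environment wires up a single-node torch.distributed run for one worker: it
# exports the RANK/LOCAL_RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables expected by
# the "torch" launcher, then initializes the InternLM parallel context from the global
# `config` via internlm.launch_from_torch.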
def build_environment(rank, world_size):
    import os

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12345"
    torch.cuda.empty_cache()
    # launcher="torch"
    internlm.launch_from_torch(config=config, seed=1024)


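# seed_all pins the Python, NumPy, and torch (CPU and CUDA) RNGs so the random losses
# produced with torch.randn_like below stay reproducible against the hard-coded reference
# tensors; cuda_deterministic additionally forces deterministic cuDNN kernels and disables
# the autotuner at the cost of speed.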
def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


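# Per-rank worker: builds a stack of four tiny PackedFlashBaseLayer1D transformer blocks
# (hidden_size=4, 2 heads, bfloat16), runs a packed-sequence forward pass driven by
# cu_seqlens/indexes, and checks both the output and the gradient of the final hidden
# states against precomputed references within bf16-friendly tolerances.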
def check_block(args):
    # init
    rank, world_size = args
    build_environment(rank, world_size)
    device = torch.device("cuda")
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)

    # define block
    blocks = nn.ModuleList(
        [
            PackedFlashBaseLayer1D(
                hidden_size=4,  # 768
                num_attention_heads=2,  # 12
                mlp_ratio=2,
                attn_drop_rate=0.0,
                drop_rate=0.0,
                dtype=torch.bfloat16,
                layer_norm_epsilon=1e-5,
                checkpoint=lid < 0,
                layer_idx=lid + 0,  # This parameter is used for caching during generation
                residual_in_fp32=False,
                device=device,
                norm_type="rmsnorm",
                dropout_selective_checkpoint=True,
                use_scaled_init=True,
                use_swiglu=True,
            )
            for lid in range(4)  # 32
        ]
    )

    # create input
    cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device)  # [0, 8, 16]
    indexes = torch.tensor([0, 1, 0, 1]).to(device)  # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
    hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device)  # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

    hidden_states = torch.tensor(
        [
            [
                [-1.1620, 1.3113, 0.1507, 2.2698],
                [-1.2610, 1.0990, 0.3787, -0.3478],
                [1.4001, 1.1982, -0.6696, 0.3269],
                [1.3304, 1.2262, 1.0735, -1.1169],
            ]
        ]
    )

    hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()

    # forward
    for _, block in enumerate(blocks):
        block = block.to(torch.bfloat16)
        block = block.to(device)
        hidden_states = block(
            hidden_states,
            cu_seqlens=cu_seqlens,
            indexes=indexes,
            inference_params=None,
            max_seqlen=max_seqlen,
        )

    result = hidden_states
    standard_result = torch.tensor(
        [
            [-1.1621, 1.3111, 0.1509, 2.2697],
            [-1.2611, 1.0988, 0.3787, -0.3478],
            [1.4000, 1.1982, -0.6694, 0.3268],
            [1.3303, 1.2262, 1.0736, -1.1169],
        ]
    ).to(device)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

    hidden_states.retain_grad()
    loss = torch.randn_like(result)

    # backward
    result.backward(loss)

    grad = hidden_states.grad
    standard_grad = torch.tensor(
        [
            [0.7999, -0.2595, 0.2649, -1.3256],
            [0.7064, 0.0283, -0.5508, 0.6494],
            [-1.4657, -2.0316, 1.3776, 0.7211],
            [-0.6046, 0.4329, -0.1884, 1.1170],
        ]
    ).to(device)

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)


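# Per-rank worker for the output head: instantiates RewardModelLinear (is_reward=True) or
# ScaleColumnParallelLinear (is_reward=False) over the tensor-parallel group and compares
# its forward output and the input gradient against precomputed bf16 references.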
def check_head(args):
    # init
    rank, world_size, is_reward = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)
    hidden_size = 4
    vocab_size = 4
    embed_grad_scale = 1

    # fix seed
    seed_all(1024)

    # load standard
    if is_reward:
        head_cls = RewardModelLinear
        standard_result = torch.tensor([[3.5938], [1.0703], [3.6250], [3.6250]], dtype=torch.bfloat16).to(device)
        standard_grad = torch.tensor(
            [
                [-0.2246, 0.0164, -0.0591, 0.1660],
                [-0.5625, 0.0408, -0.1484, 0.4160],
                [-0.1758, 0.0128, -0.0464, 0.1299],
                [-0.4785, 0.0347, -0.1260, 0.3516],
            ],
            dtype=torch.bfloat16,
        ).to(device)
    else:
        head_cls = ScaleColumnParallelLinear
        standard_result = torch.tensor(
            [
                [3.5938, -2.2188, 2.0312, 3.5625],
                [1.0703, -1.1797, 1.1406, 1.6641],
                [3.6250, -2.0156, 1.7656, 3.4531],
                [3.6250, -2.0156, 1.7656, 3.4531],
            ],
            dtype=torch.bfloat16,
        ).to(device)
        standard_grad = torch.tensor(
            [
                [-0.2354, 0.0981, -0.2930, -0.6328],
                [0.2344, -0.2334, -0.0918, 0.1396],
                [-0.5898, -1.0156, -0.7070, 1.3750],
                [0.0242, -0.1494, 0.1206, -0.0427],
            ],
            dtype=torch.bfloat16,
        ).to(device)

    # define head
    head = head_cls(
        in_features=hidden_size,
        out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
        process_group=gpc.get_group(ParallelMode.TENSOR),
        bias=False,
        device=device,
        dtype=torch.bfloat16,
        weight_scale=embed_grad_scale,
    )

    head = head.to(torch.bfloat16)
    head = head.to(device)

    # create input
    hidden_states = torch.tensor(
        [
            [8.3726, 1.9245, 5.5101, 1.0000],
            [3.3474, 2.9582, 1.0000, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
        ],
        dtype=torch.bfloat16,
        requires_grad=True,
    ).to(device)

    # forward
    result = head(hidden_states)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

    hidden_states.retain_grad()
    loss = torch.randn_like(result)

    # backward
    result.backward(loss)
    grad = hidden_states.grad

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)


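# Per-rank worker for gather_forward_split_backward: with tensor parallel size 1 the op is
# effectively an identity (output equals input), while with size 2 the forward gathers the
# last dimension across the tensor-parallel pair (doubling the width) and the backward hands
# each rank its own split of the incoming gradient, which is what the rank-dependent
# reference tensors encode.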
def check_gather_forward(args):
    # init
    rank, world_size, parallel_tensor = args
    assert parallel_tensor in [1, 2]
    config.parallel.tensor = parallel_tensor
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)

    # load standard
    if parallel_tensor == 1:
        standard_result = torch.tensor(
            [
                [8.3726, 1.9245, 5.5101, 1.0000],
                [3.3474, 2.9582, 1.0000, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000],
            ]
        ).to(device)
        standard_grad = torch.tensor(
            [
                [-0.4461, 0.5602, -0.0625, -1.3609],
                [0.4353, 1.2988, 0.9595, -0.1144],
                [-0.7593, -0.4031, 0.2041, 1.4955],
                [0.5706, 0.9047, -0.6965, -0.3757],
            ]
        ).to(device)
    else:
        standard_result = torch.tensor(
            [
                [8.3726, 1.9245, 5.5101, 1.0000, 8.3726, 1.9245, 5.5101, 1.0000],
                [3.3474, 2.9582, 1.0000, 1.0000, 3.3474, 2.9582, 1.0000, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
            ]
        ).to(device)
        if rank % 2 == 0:
            standard_grad = torch.tensor(
                [
                    [-0.4461, 0.5602, -0.0625, -1.3609],
                    [-0.7593, -0.4031, 0.2041, 1.4955],
                    [0.8093, 1.7580, 1.2996, -0.7545],
                    [1.0474, -0.5767, -1.0401, 0.8233],
                ]
            ).to(device)
        else:
            standard_grad = torch.tensor(
                [
                    [0.4353, 1.2988, 0.9595, -0.1144],
                    [0.5706, 0.9047, -0.6965, -0.3757],
                    [-1.3589, -0.7202, 0.6094, -0.8208],
                    [-1.0042, 0.3695, 0.2511, -0.2718],
                ]
            ).to(device)

    # create input
    hidden_states = torch.tensor(
        [
            [8.3726, 1.9245, 5.5101, 1.0000],
            [3.3474, 2.9582, 1.0000, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
        ],
        requires_grad=True,
    ).to(device)

    # forward
    result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

    loss = torch.randn_like(result)
    hidden_states.retain_grad()

    # backward
    result.backward(loss)
    grad = hidden_states.grad

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)


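# Each test below spawns 8 worker processes to emulate an 8-rank (8-GPU) single-node job;
# the markers allow selecting a suite individually, e.g. (assuming an 8-GPU node):
#   pytest -s -q -m block test_model_internlm.py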
@pytest.mark.block
def test_block():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_block, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.head
@pytest.mark.parametrize("is_reward", [True, False])
def test_head(is_reward):
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_head, [[rank, 8, is_reward] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.gather_forward
@pytest.mark.parametrize("parallel_tensor", [1, 2])
def test_gather_forward(parallel_tensor):
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_gather_forward, [[rank, 8, parallel_tensor] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_model_internlm.py"])