Merge remote-tracking branch 'upstream/develop' into develop

pull/304/head
huangting4201 2023-09-14 16:33:44 +08:00
commit a601d5aa09
6 changed files with 583 additions and 19 deletions

View File

@@ -127,7 +127,7 @@ model = dict(
     num_layers=NUM_LAYER,
     mlp_ratio=MLP_RATIO,
     apply_post_layer_norm=False,
-    dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
+    dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
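The default training dtype moves from "torch.float16" to "torch.bfloat16". For context, a minimal sketch of how such a dtype string could be resolved at runtime; this is illustrative only and not InternLM's actual config handling, and note that "torch.tf32" is a matmul execution mode rather than a storage dtype:

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # Illustrative mapping only; the real parsing in the trainer may differ.
    mapping = {
        "torch.float16": torch.float16,
        "torch.half": torch.float16,
        "torch.bfloat16": torch.bfloat16,
        "torch.float32": torch.float32,
    }
    if name == "torch.tf32":
        # tf32 is enabled on matmul/cuDNN kernels; tensors stay float32.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        return torch.float32
    return mapping[name]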

View File

@@ -0,0 +1,65 @@
import multiprocessing as mp

import pytest
import torch

from internlm.model.embedding import Embedding1D
from tests.test_model.test_model_internlm import build_environment, seed_all


def check_embedding(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)
    vocab_size = 4
    hidden_size = 2

    # fix seed
    seed_all(1024)

    # define embedding
    embedding = Embedding1D(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        padding_idx=None,
    )
    embedding.weight.data.copy_(torch.randn(vocab_size, hidden_size))
    embedding = embedding.to(device)

    # create input
    input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
    result = embedding(input_ids)

    standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
    standard_result = torch.tensor(standard_list).to(device)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol, equal_nan=True)

    loss = torch.randn_like(result)

    # backward
    result.backward(loss)

    grad = embedding.weight.grad
    standard_glist = [[-0.4461, 0.5602], [0.4353, 1.2988], [-0.0625, -1.3609], [0.9595, -0.1144]]
    standard_grad = torch.tensor(standard_glist).to(device)

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)


@pytest.mark.embedding
def test_embedding():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_embedding, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_embedding.py"])
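As a quick sanity check outside the test suite: with a tensor-parallel size of 1, Embedding1D is expected to behave like a plain lookup table, so its output can be compared against torch.nn.Embedding holding the same weight. A hypothetical single-process sketch (not part of this commit), with the weight rows taken from the expected values above:

import torch
from torch import nn

# Rows recovered from standard_result: row i is the embedding of token id i.
weight = torch.tensor([[-1.4837, 0.2671], [-1.8337, -0.1047], [0.6002, -0.5496], [1.0391, 0.2261]])
ref = nn.Embedding(4, 2)
ref.weight.data.copy_(weight)

input_ids = torch.tensor([[0, 2], [1, 3]])
print(ref(input_ids))  # matches the standard_result checked in check_embedding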

View File

@@ -0,0 +1,379 @@
import multiprocessing as mp
import random

import numpy as np
import pytest
import torch
from torch import nn

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.context.parallel_context import global_context as gpc
from internlm.model.linear import RewardModelLinear, ScaleColumnParallelLinear
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.model.utils import gather_forward_split_backward

config = Config(
    dict(
        parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
        model_type="INTERNLM",
        data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
        model=dict(
            checkpoint=False,
            num_attention_heads=2,
            embed_split_hidden=True,
            vocab_size=103168,
            embed_grad_scale=1,
            parallel_output=True,
            hidden_size=1024,
            num_layers=2,
            mlp_ratio=1,
            apply_post_layer_norm=False,
            dtype=torch.bfloat16,
            norm_type="rmsnorm",
            layer_norm_epsilon=1e-5,
            use_flash_attn=True,
            num_chunks=1,
        ),
        resume_tb_folder="",
        tensorboard_folder="",
        alert_address=None,
        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
    )
)


def build_environment(rank, world_size):
    import os

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12345"
    torch.cuda.empty_cache()
    # launcher="torch"
    internlm.launch_from_torch(config=config, seed=1024)


def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def check_block(args):
    # init
    rank, world_size = args
    build_environment(rank, world_size)
    device = torch.device("cuda")
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)

    # define block
    blocks = nn.ModuleList(
        [
            PackedFlashBaseLayer1D(
                hidden_size=4,  # 768
                num_attention_heads=2,  # 12
                mlp_ratio=2,
                attn_drop_rate=0.0,
                drop_rate=0.0,
                dtype=torch.bfloat16,
                layer_norm_epsilon=1e-5,
                checkpoint=lid < 0,
                layer_idx=lid + 0,  # This parameter is used for caching during generation
                residual_in_fp32=False,
                device=device,
                norm_type="rmsnorm",
                dropout_selective_checkpoint=True,
                use_scaled_init=True,
                use_swiglu=True,
            )
            for lid in range(4)  # 32
        ]
    )

    # create input
    cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device)  # [0, 8, 16]
    indexes = torch.tensor([0, 1, 0, 1]).to(device)  # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
    hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device)  # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

    hidden_states = torch.tensor(
        [
            [
                [-1.1620, 1.3113, 0.1507, 2.2698],
                [-1.2610, 1.0990, 0.3787, -0.3478],
                [1.4001, 1.1982, -0.6696, 0.3269],
                [1.3304, 1.2262, 1.0735, -1.1169],
            ]
        ]
    )
    hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()

    # forward
    for _, block in enumerate(blocks):
        block = block.to(torch.bfloat16)
        block = block.to(device)
        hidden_states = block(
            hidden_states,
            cu_seqlens=cu_seqlens,
            indexes=indexes,
            inference_params=None,
            max_seqlen=max_seqlen,
        )

    result = hidden_states
    standard_result = torch.tensor(
        [
            [-1.1621, 1.3111, 0.1509, 2.2697],
            [-1.2611, 1.0988, 0.3787, -0.3478],
            [1.4000, 1.1982, -0.6694, 0.3268],
            [1.3303, 1.2262, 1.0736, -1.1169],
        ]
    ).to(device)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

    hidden_states.retain_grad()
    loss = torch.randn_like(result)

    # backward
    result.backward(loss)
    grad = hidden_states.grad
    standard_grad = torch.tensor(
        [
            [0.7999, -0.2595, 0.2649, -1.3256],
            [0.7064, 0.0283, -0.5508, 0.6494],
            [-1.4657, -2.0316, 1.3776, 0.7211],
            [-0.6046, 0.4329, -0.1884, 1.1170],
        ]
    ).to(device)

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)


def check_head(args):
    # init
    rank, world_size, is_reward = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)
    hidden_size = 4
    vocab_size = 4
    embed_grad_scale = 1

    # fix seed
    seed_all(1024)

    # load standard
    if is_reward:
        head_cls = RewardModelLinear
        standard_result = torch.tensor([[3.5938], [1.0703], [3.6250], [3.6250]], dtype=torch.bfloat16).to(device)
        standard_grad = torch.tensor(
            [
                [-0.2246, 0.0164, -0.0591, 0.1660],
                [-0.5625, 0.0408, -0.1484, 0.4160],
                [-0.1758, 0.0128, -0.0464, 0.1299],
                [-0.4785, 0.0347, -0.1260, 0.3516],
            ],
            dtype=torch.bfloat16,
        ).to(device)
    else:
        head_cls = ScaleColumnParallelLinear
        standard_result = torch.tensor(
            [
                [3.5938, -2.2188, 2.0312, 3.5625],
                [1.0703, -1.1797, 1.1406, 1.6641],
                [3.6250, -2.0156, 1.7656, 3.4531],
                [3.6250, -2.0156, 1.7656, 3.4531],
            ],
            dtype=torch.bfloat16,
        ).to(device)
        standard_grad = torch.tensor(
            [
                [-0.2354, 0.0981, -0.2930, -0.6328],
                [0.2344, -0.2334, -0.0918, 0.1396],
                [-0.5898, -1.0156, -0.7070, 1.3750],
                [0.0242, -0.1494, 0.1206, -0.0427],
            ],
            dtype=torch.bfloat16,
        ).to(device)

    # define head
    head = head_cls(
        in_features=hidden_size,
        out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
        process_group=gpc.get_group(ParallelMode.TENSOR),
        bias=False,
        device=device,
        dtype=torch.bfloat16,
        weight_scale=embed_grad_scale,
    )
    head = head.to(torch.bfloat16)
    head = head.to(device)

    # create input
    hidden_states = torch.tensor(
        [
            [8.3726, 1.9245, 5.5101, 1.0000],
            [3.3474, 2.9582, 1.0000, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
        ],
        dtype=torch.bfloat16,
        requires_grad=True,
    ).to(device)

    # forward
    result = head(hidden_states)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

    hidden_states.retain_grad()
    loss = torch.randn_like(result)

    # backward
    result.backward(loss)
    grad = hidden_states.grad

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)


def check_gather_forward(args):
    # init
    rank, world_size, parallel_tensor = args
    assert parallel_tensor in [1, 2]
    config.parallel.tensor = parallel_tensor
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)

    # load standard
    if parallel_tensor == 1:
        standard_result = torch.tensor(
            [
                [8.3726, 1.9245, 5.5101, 1.0000],
                [3.3474, 2.9582, 1.0000, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000],
            ]
        ).to(device)
        standard_grad = torch.tensor(
            [
                [-0.4461, 0.5602, -0.0625, -1.3609],
                [0.4353, 1.2988, 0.9595, -0.1144],
                [-0.7593, -0.4031, 0.2041, 1.4955],
                [0.5706, 0.9047, -0.6965, -0.3757],
            ]
        ).to(device)
    else:
        standard_result = torch.tensor(
            [
                [8.3726, 1.9245, 5.5101, 1.0000, 8.3726, 1.9245, 5.5101, 1.0000],
                [3.3474, 2.9582, 1.0000, 1.0000, 3.3474, 2.9582, 1.0000, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
                [8.3726, 1.2875, 5.5101, 1.0000, 8.3726, 1.2875, 5.5101, 1.0000],
            ]
        ).to(device)
        if rank % 2 == 0:
            standard_grad = torch.tensor(
                [
                    [-0.4461, 0.5602, -0.0625, -1.3609],
                    [-0.7593, -0.4031, 0.2041, 1.4955],
                    [0.8093, 1.7580, 1.2996, -0.7545],
                    [1.0474, -0.5767, -1.0401, 0.8233],
                ]
            ).to(device)
        else:
            standard_grad = torch.tensor(
                [
                    [0.4353, 1.2988, 0.9595, -0.1144],
                    [0.5706, 0.9047, -0.6965, -0.3757],
                    [-1.3589, -0.7202, 0.6094, -0.8208],
                    [-1.0042, 0.3695, 0.2511, -0.2718],
                ]
            ).to(device)

    # create input
    hidden_states = torch.tensor(
        [
            [8.3726, 1.9245, 5.5101, 1.0000],
            [3.3474, 2.9582, 1.0000, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
        ],
        requires_grad=True,
    ).to(device)

    # forward
    result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)

    # check output
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

    loss = torch.randn_like(result)
    hidden_states.retain_grad()

    # backward
    result.backward(loss)
    grad = hidden_states.grad

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol)


@pytest.mark.block
def test_block():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_block, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.head
@pytest.mark.parametrize("is_reward", [True, False])
def test_head(is_reward):
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_head, [[rank, 8, is_reward] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.gather_forward
@pytest.mark.parametrize("parallel_tensor", [1, 2])
def test_gather_forward(parallel_tensor):
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_gather_forward, [[rank, 8, parallel_tensor] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_model_internlm.py"])
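check_gather_forward exercises gather_forward_split_backward, which all-gathers activations along the given dimension in the forward pass and, in the backward pass, hands each rank only the slice of the incoming gradient that corresponds to its own shard; that is why the TP=2 expected output has doubled columns while standard_grad differs per rank. A minimal sketch of that autograd pattern, written against a raw torch.distributed process group rather than InternLM's ParallelMode, purely for illustration and not the library's actual implementation:

import torch
import torch.distributed as dist

class GatherForwardSplitBackward(torch.autograd.Function):
    """Sketch of the gather-forward / split-backward pattern (illustrative only)."""

    @staticmethod
    def forward(ctx, x, group, dim):
        ctx.group, ctx.dim = group, dim
        world_size = dist.get_world_size(group)
        if world_size == 1:
            return x
        parts = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(parts, x.contiguous(), group=group)  # concatenate shards from all ranks
        return torch.cat(parts, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        if world_size == 1:
            return grad_output, None, None
        rank = dist.get_rank(ctx.group)
        # each rank keeps only the gradient slice belonging to its own shard
        return grad_output.chunk(world_size, dim=ctx.dim)[rank].contiguous(), None, None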

View File

@@ -0,0 +1,84 @@
import multiprocessing as mp

import pytest
import torch

from internlm.model.utils import try_import_RMSNorm
from tests.test_model.test_model_internlm import build_environment, seed_all

RMSNorm = try_import_RMSNorm()


def check_norm(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)
    hidden_size = 4
    layer_norm_epsilon = 1e-05

    # fix seed
    seed_all(1024)

    # define norm
    norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
    norm = norm.to(device)

    # create input
    hidden_states = torch.tensor(
        [
            [8.3726, 1.9245, 5.5101, 1.0000],
            [3.3474, 2.9582, 1.0000, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
            [8.3726, 1.2875, 5.5101, 1.0000],
        ],
        requires_grad=True,
    ).to(device)

    # forward
    result = norm(hidden_states.float())
    standard = torch.tensor(
        [
            [1.6329, 0.3753, 1.0746, 0.1950],
            [1.4288, 1.2626, 0.4268, 0.4268],
            [1.6490, 0.2536, 1.0852, 0.1970],
            [1.6490, 0.2536, 1.0852, 0.1970],
        ]
    ).to(device)

    # check output
    assert torch.allclose(result, standard, rtol=rtol, atol=atol, equal_nan=True)

    hidden_states.retain_grad()
    loss = torch.randn_like(result)

    # backward
    result.backward(loss)
    grad = hidden_states.grad
    standard_grad = torch.tensor(
        [
            [-0.0193, 0.1248, 0.0324, -0.2573],
            [-0.2140, 0.2010, 0.2901, -0.1683],
            [-0.0815, -0.0689, 0.0850, 0.3027],
            [0.0847, 0.1739, -0.1554, -0.0773],
        ]
    ).to(device)

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)


@pytest.mark.norm
def test_norm():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_norm, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_norm.py"])
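The expected values in check_norm are consistent with the usual RMSNorm definition, y = x / sqrt(mean(x^2) + eps) * weight, assuming the scale weight is still at its initialization of ones. A small reference sketch that reproduces the first expected row:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

x = torch.tensor([8.3726, 1.9245, 5.5101, 1.0000])
print(rmsnorm_reference(x, torch.ones(4)))  # ~[1.6329, 0.3753, 1.0746, 0.1950]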

View File

@@ -38,7 +38,7 @@ def convert2hf(model_config, states_tp_pps):
     current_states["lm_head.weight"] = states.pop("head.weight")
     for i in range(model_config["num_layers"]):
-        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq")
+        states.pop(f"blocks.{i}.mixer.rotary_emb.inv_freq", None)
         wqkv = states.pop(f"blocks.{i}.mixer.Wqkv.weight").reshape(
             3, model_config["num_attention_heads"], -1, model_config["hidden_size"]
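The only change here is passing a default to dict.pop, so the conversion no longer raises a KeyError when a checkpoint does not contain the rotary inv_freq buffer; the key is simply skipped. A tiny illustration of the difference:

states = {"blocks.0.mixer.Wqkv.weight": 0}
states.pop("blocks.0.mixer.rotary_emb.inv_freq", None)  # returns None, no error
# states.pop("blocks.0.mixer.rotary_emb.inv_freq")      # would raise KeyError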

View File

@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue

 import torch
 import torch.utils.checkpoint
@@ -810,23 +811,47 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     temperature: float = 0.8,
                     top_p: float = 0.8,
                     **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
+        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
+        """
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))

             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]
+
+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                 if token.strip() != "<eoa>":
-                    print(token, end="")
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))

             def end(self):
-                print("")
+                self.queue.put(None)

-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
@@ -839,6 +864,17 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
-            **kwargs
-        )
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()

 @add_start_docstrings(
     """