ColossalAI/applications/Chat/tests/test_models.py

221 lines
8.2 KiB
Python

import copy
from typing import Any, Callable, Dict, Tuple
import pytest
import torch
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.chatglm import ChatGLMActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize(
"actor_maker",
[
lambda: BLOOMActor(),
lambda: GPTActor(),
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
lambda: OPTActor(),
# lambda: ChatGLMActor(),
])
@pytest.mark.parametrize("generate_kwargs", [{
"max_length": 64,
"use_cache": True,
"do_sample": True,
"temperature": 1.0,
"top_k": 50,
}])
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
def test_utils():
fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
fn_output = masked_mean(dim=0, **fn_input)
assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0))
batch_size = 4
num_labels = 10
fn_input = {
"r": torch.ones((batch_size,)),
"kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels))
}
fn_output = compute_reward(**fn_input)
assert fn_output.shape == (batch_size,)
batch_size = 4
seq_len = 32
num_labels = 10
num_actions = 2
fn_input = {
"output": {
"logits": torch.randn((batch_size, seq_len, num_labels))
},
"sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
"num_actions": num_actions,
}
fn_output = calc_action_log_probs(**fn_input)
assert fn_output.shape == (batch_size, num_actions)
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int, num_dim: int, num_layers: int):
model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
lora_model = convert_to_lora_module(model, lora_rank)
assert isinstance(lora_model, nn.ModuleList)
for i in range(num_layers):
assert isinstance(lora_model[i], LoraLinear)
assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
assert lora_model[i].lora_B.shape == (num_dim, lora_rank)
old_model = copy.deepcopy(lora_model)
for i in range(num_layers):
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
optimizer = torch.optim.Adam(lora_model.parameters())
x = torch.randn(8, num_dim)
for i in range(num_layers):
x = lora_model[i](x)
loss = x.sum()
loss.backward()
optimizer.step()
for i in range(num_layers):
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
lora_model[i].lora_B @ lora_model[i].lora_A)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
"models_maker",
[
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()),
lambda: (ChatGLMActor(), None, None),
])
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
batch_size: int,
seq_len: int):
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
}
critic_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
"action_mask": torch.randint(0, 2, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
}
rm_input = {
"sequences": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
}
actor, critic, rm = models_maker()
if isinstance(actor, ChatGLMActor):
actor = actor.float()
tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True)
chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
actor_input ={
"input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1),
"attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
}
assert isinstance(actor, Actor)
base_actor_model = get_base_model(actor)
actor_output = actor(**actor_input)
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
if critic:
assert isinstance(critic, Critic)
base_critic_model = get_base_model(critic)
critic_output = critic(**critic_input)
assert critic_output.shape == (batch_size, )
if rm:
assert isinstance(rm, RewardModel)
base_rm_model = get_base_model(rm)
rm_output = rm(**rm_input)
assert rm_output.shape == (batch_size, )
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
"labels": torch.randint(0, num_labels, (batch_size, seq_len))
}
loss_output = loss(**loss_input)
loss = PolicyLoss()
loss_input = {
"log_probs": torch.randn(batch_size,),
"old_log_probs": torch.randn(batch_size,),
"advantages": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = ValueLoss()
loss_input = {
"values": torch.randn(batch_size,),
"old_values": torch.randn(batch_size,),
"reward": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = LogSigLoss()
loss_input = {
"chosen_reward": torch.randn(batch_size,),
"reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
loss = LogExpLoss()
loss_input = {
"chosen_reward": torch.randn(batch_size,),
"reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
if __name__ == "__main__":
generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
test_utils()
test_lora(lora_rank=2, num_dim=8, num_layers=2)
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
test_loss(batch_size=8, seq_len=128, num_labels=100)