# ColossalAI/applications/Chat/tests/test_models.py
#
# 253 lines
# 8.4 KiB
# Python

import copy
from typing import Any, Callable, Dict, Tuple
import pytest
import torch
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize(
    "actor_maker",
    [
        lambda: BLOOMActor(),
        lambda: GPTActor(),
        # HACK: skip llama due to long execution time
        # lambda: LlamaActor(),
        lambda: OPTActor(),
        # lambda: ChatGLMActor(),
    ],
)
@pytest.mark.parametrize(
    "generate_kwargs",
    [
        {
            "max_length": 64,
            "use_cache": True,
            "do_sample": True,
            "temperature": 1.0,
            "top_k": 50,
        }
    ],
)
def test_generation(
    actor_maker: Callable[[], Actor],
    batch_size: int,
    seq_len: int,
    generate_kwargs: Dict[str, Any],
):
    """Check that ``generate`` returns sequences of the requested length.

    A fresh actor is built by ``actor_maker``, moved to the GPU together with a
    random prompt of shape ``(batch_size, seq_len)``, and handed to ``generate``.
    The result must have shape ``(batch_size, generate_kwargs["max_length"])``.

    The test calls ``.cuda()`` unconditionally, so it needs a CUDA device.
    """
    actor = actor_maker()
    # Token ids are drawn from [0, 100) on the CPU and then copied to the GPU.
    prompt_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()

    generated = generate(actor.cuda(), prompt_ids, **generate_kwargs)

    expected_shape = (batch_size, generate_kwargs["max_length"])
    assert generated.shape == expected_shape
def test_utils():
    """Exercise ``masked_mean``, ``compute_reward`` and ``calc_action_log_probs``.

    Each helper is called once with small random inputs:

    * ``masked_mean`` over a vector of ones must reduce to a 0-dim tensor equal
      to 1.0.
    * ``compute_reward`` must return one value per batch element.
    * ``calc_action_log_probs`` must return shape ``(batch_size, num_actions)``.
    """
    batch_size = 4
    seq_len = 32
    num_labels = 10
    num_actions = 2

    # --- masked_mean ---------------------------------------------------------
    mask = torch.randint(0, 2, (10,))
    # A random 0/1 mask of length 10 is all-zero with probability 2**-10. Such a
    # mask selects nothing, so the result would not be expected to equal 1.0 and
    # the assertion below would fail spuriously. Keep one position switched on.
    mask[0] = 1
    fn_input = {"tensor": torch.ones((10,)), "mask": mask}
    fn_output = masked_mean(dim=0, **fn_input)
    assert fn_output.dim() == 0
    assert torch.allclose(fn_output, torch.tensor(1.0))

    # --- compute_reward ------------------------------------------------------
    fn_input = {
        "r": torch.ones((batch_size,)),
        "kl_coef": 1.0,
        "log_probs": torch.randn((batch_size, num_labels)),
        "log_probs_base": torch.randn((batch_size, num_labels)),
        "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
    }
    fn_output = compute_reward(**fn_input)
    assert fn_output.shape == (batch_size,)

    # --- calc_action_log_probs -----------------------------------------------
    fn_input = {
        "output": {"logits": torch.randn((batch_size, seq_len, num_labels))},
        "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
        "num_actions": num_actions,
    }
    fn_output = calc_action_log_probs(**fn_input)
    assert fn_output.shape == (batch_size, num_actions)
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int, num_dim: int, num_layers: int):
    """Check LoRA conversion and that one optimizer step only moves the adapters.

    A stack of square ``nn.Linear`` layers is converted with
    ``convert_to_lora_module``. The test then verifies that:

    1. every layer became a ``LoraLinear`` with ``lora_A`` of shape
       ``(lora_rank, num_dim)`` and ``lora_B`` of shape ``(num_dim, lora_rank)``;
    2. a deep copy matches the converted model before training;
    3. after one Adam step, ``weight`` and ``bias`` still match the copy while
       the product ``lora_B @ lora_A`` no longer does.
    """
    linear_stack = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
    lora_model = convert_to_lora_module(linear_stack, lora_rank)
    assert isinstance(lora_model, nn.ModuleList)
    for idx in range(num_layers):
        layer = lora_model[idx]
        assert isinstance(layer, LoraLinear)
        assert layer.lora_A.shape == (lora_rank, num_dim)
        assert layer.lora_B.shape == (num_dim, lora_rank)

    # Reference copy taken before any parameter update.
    old_model = copy.deepcopy(lora_model)

    def low_rank_delta(layer):
        # Product of the two adapter matrices of a single layer.
        return layer.lora_B @ layer.lora_A

    def compare_with_reference(adapters_unchanged: bool):
        # Weight and bias must always match the reference copy. The adapter
        # product must match it exactly when ``adapters_unchanged`` is True
        # and must differ from it otherwise.
        for idx in range(num_layers):
            before, after = old_model[idx], lora_model[idx]
            assert isinstance(after, LoraLinear)
            assert torch.allclose(before.weight, after.weight)
            assert torch.allclose(before.bias, after.bias)
            same_delta = torch.allclose(low_rank_delta(before), low_rank_delta(after))
            assert same_delta == adapters_unchanged

    compare_with_reference(adapters_unchanged=True)

    # One forward/backward pass through all layers followed by one Adam step.
    optimizer = torch.optim.Adam(lora_model.parameters())
    activations = torch.randn(8, num_dim)
    for idx in range(num_layers):
        activations = lora_model[idx](activations)
    activations.sum().backward()
    optimizer.step()

    compare_with_reference(adapters_unchanged=False)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "models_maker",
    [
        lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
        lambda: (GPTActor(), GPTCritic(), GPTRM()),
        # HACK: skip llama due to long execution time
        # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
        lambda: (OPTActor(), OPTCritic(), OPTRM()),
        lambda: (ChatGLMActor(), None, None),
    ],
)
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
    """Run a forward pass through an actor and, when provided, a critic and reward model.

    Args:
        models_maker: Factory returning ``(actor, critic, reward_model)``. The
            critic and reward model may be ``None`` (as in the ChatGLM case),
            in which case their checks are skipped.
        batch_size: Number of sequences in each random batch.
        seq_len: Length of each random sequence. The ChatGLM input is built
            from two halves, so its length equals ``seq_len`` only when
            ``seq_len`` is even.

    The actor's logits must have leading shape ``(batch_size, seq_len)``; the
    critic and reward model must each return shape ``(batch_size,)``.
    """
    actor_input = {
        "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }
    critic_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }
    rm_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len)),
    }

    actor, critic, rm = models_maker()

    if isinstance(actor, ChatGLMActor):
        # ChatGLM gets a dedicated input: the model is cast to float32, the
        # gMASK and BOS token ids are spliced into the middle of the sequence,
        # and the attention mask is 4-D (batch, 1, seq, seq).
        # NOTE(review): presumably ChatGLM's forward requires these special
        # tokens to be present - confirm against the ChatGLM model code.
        # Loading the tokenizer fetches "THUDM/chatglm-6b" via from_pretrained.
        actor = actor.float()
        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
        actor_input = {
            "input_ids": torch.cat(
                (
                    torch.randint(0, 100, (batch_size, seq_len // 2)),
                    chatglm_special_token,
                    # Two positions fewer so the two special tokens fit.
                    torch.randint(0, 100, (batch_size, seq_len // 2 - 2)),
                ),
                dim=1,
            ),
            "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)),
        }

    assert isinstance(actor, Actor)
    get_base_model(actor)
    actor_output = actor(**actor_input)
    assert actor_output.logits.shape[:2] == (batch_size, seq_len)

    # Explicit None checks: the placeholders are None, and the decision should
    # not depend on how a module object evaluates in a boolean context.
    if critic is not None:
        assert isinstance(critic, Critic)
        get_base_model(critic)
        critic_output = critic(**critic_input)
        assert critic_output.shape == (batch_size,)

    if rm is not None:
        assert isinstance(rm, RewardModel)
        get_base_model(rm)
        rm_output = rm(**rm_input)
        assert rm_output.shape == (batch_size,)
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int, seq_len: int, num_labels: int):
    """Smoke-test each loss class by calling it once on random inputs.

    Nothing is asserted about the returned values; the test passes when
    ``GPTLMLoss``, ``PolicyLoss``, ``ValueLoss``, ``LogSigLoss`` and
    ``LogExpLoss`` can each be constructed with default arguments and invoked
    without raising.
    """

    def per_sample():
        # One standard-normal draw per batch element, shape (batch_size,).
        return torch.randn(batch_size)

    # Language-modelling loss: per-token logits against integer labels.
    lm_loss = GPTLMLoss()
    lm_loss(
        logits=torch.randn(batch_size, seq_len, num_labels),
        labels=torch.randint(0, num_labels, (batch_size, seq_len)),
    )

    policy_loss = PolicyLoss()
    policy_loss(log_probs=per_sample(), old_log_probs=per_sample(), advantages=per_sample())

    value_loss = ValueLoss()
    value_loss(values=per_sample(), old_values=per_sample(), reward=per_sample())

    # The two pairwise losses take the same pair of keyword arguments.
    log_sig_loss = LogSigLoss()
    log_sig_loss(chosen_reward=per_sample(), reject_reward=per_sample())

    log_exp_loss = LogExpLoss()
    log_exp_loss(chosen_reward=per_sample(), reject_reward=per_sample())
if __name__ == "__main__":
    # Manual entry point: run every test once with fixed arguments, without
    # going through pytest's parametrization.
    generate_kwargs = {
        "max_length": 40,
        "use_cache": True,
        "do_sample": True,
        "temperature": 1.0,
        "top_k": 50,
    }
    # Unlike the pytest parametrization, the manual run exercises LlamaActor.
    test_generation(
        lambda: LlamaActor(),
        batch_size=4,
        seq_len=32,
        generate_kwargs=generate_kwargs,
    )

    test_utils()

    test_lora(lora_rank=2, num_dim=8, num_layers=2)

    test_models(
        models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
        batch_size=8,
        seq_len=128,
    )

    test_loss(batch_size=8, seq_len=128, num_labels=100)