import copy
from typing import Any, Callable, Dict, Tuple

import pytest
import torch
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel, get_base_model
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean


@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize(
    "actor_maker",
    [
        lambda: BLOOMActor(),
        lambda: GPTActor(),
        # HACK: skip llama due to long execution time
        # lambda: LlamaActor(),
        lambda: OPTActor(),
        # lambda: ChatGLMActor(),
    ])
@pytest.mark.parametrize("generate_kwargs", [{
    "max_length": 64,
    "use_cache": True,
    "do_sample": True,
    "temperature": 1.0,
    "top_k": 50,
}])
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
    # Generated sequences are expected to have exactly `max_length` tokens.
    actor = actor_maker()
    input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
    sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
    assert sequences.shape == (batch_size, generate_kwargs["max_length"])


def test_utils():
    # masked_mean over an all-ones tensor must reduce to a scalar equal to 1.0.
    fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
    fn_output = masked_mean(dim=0, **fn_input)
    assert fn_output.dim() == 0
    assert torch.allclose(fn_output, torch.tensor(1.0))

    batch_size = 4
    num_labels = 10
    fn_input = {
        "r": torch.ones((batch_size,)),
        "kl_coef": 1.0,
        "log_probs": torch.randn((batch_size, num_labels)),
        "log_probs_base": torch.randn((batch_size, num_labels)),
        "action_mask": torch.randint(0, 2, (batch_size, num_labels)),
    }
    fn_output = compute_reward(**fn_input)
    assert fn_output.shape == (batch_size,)

    batch_size = 4
    seq_len = 32
    num_labels = 10
    num_actions = 2
    fn_input = {
        "output": {
            "logits": torch.randn((batch_size, seq_len, num_labels))
        },
        "sequences": torch.randint(0, num_labels, (batch_size, seq_len)),
        "num_actions": num_actions,
    }
    fn_output = calc_action_log_probs(**fn_input)
    assert fn_output.shape == (batch_size, num_actions)


@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int, num_dim: int, num_layers: int):
    model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
    lora_model = convert_to_lora_module(model, lora_rank)
    assert isinstance(lora_model, nn.ModuleList)
    for i in range(num_layers):
        assert isinstance(lora_model[i], LoraLinear)
        assert lora_model[i].lora_A.shape == (lora_rank, num_dim)
        assert lora_model[i].lora_B.shape == (num_dim, lora_rank)

    old_model = copy.deepcopy(lora_model)
    for i in range(num_layers):
        assert isinstance(lora_model[i], LoraLinear)
        assert torch.allclose(old_model[i].weight, lora_model[i].weight)
        assert torch.allclose(old_model[i].bias, lora_model[i].bias)
        assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
                              lora_model[i].lora_B @ lora_model[i].lora_A)

    optimizer = torch.optim.Adam(lora_model.parameters())
    x = torch.randn(8, num_dim)
    for i in range(num_layers):
        x = lora_model[i](x)
    loss = x.sum()
    loss.backward()
    optimizer.step()
    # After one optimizer step, only the LoRA matrices should have changed;
    # the frozen base weights and biases must remain identical.
    for i in range(num_layers):
        assert isinstance(lora_model[i], LoraLinear)
        assert torch.allclose(old_model[i].weight, lora_model[i].weight)
        assert torch.allclose(old_model[i].bias, lora_model[i].bias)
        assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
                                  lora_model[i].lora_B @ lora_model[i].lora_A)


@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "models_maker",
    [
        lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
        lambda: (GPTActor(), GPTCritic(), GPTRM()),
        # HACK: skip llama due to long execution time
        # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
        lambda: (OPTActor(), OPTCritic(), OPTRM()),
        lambda: (ChatGLMActor(), None, None),
    ])
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int):
    actor_input = {
        "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
    }
    critic_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "action_mask": torch.randint(0, 2, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
    }
    rm_input = {
        "sequences": torch.randint(0, 100, (batch_size, seq_len)),
        "attention_mask": torch.randint(0, 2, (batch_size, seq_len))
    }

    actor, critic, rm = models_maker()
    if isinstance(actor, ChatGLMActor):
        # ChatGLM needs its gMASK/BOS special tokens in the prompt and a 4-D attention mask.
        actor = actor.float()
        tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
        actor_input = {
            "input_ids":
                torch.cat((torch.randint(0, 100, (batch_size, seq_len // 2)), chatglm_special_token,
                           torch.randint(0, 100, (batch_size, seq_len // 2 - 2))),
                          dim=1),
            "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
        }
    assert isinstance(actor, Actor)
    base_actor_model = get_base_model(actor)
    actor_output = actor(**actor_input)
    assert actor_output.logits.shape[:2] == (batch_size, seq_len)

    if critic:
        assert isinstance(critic, Critic)
        base_critic_model = get_base_model(critic)
        critic_output = critic(**critic_input)
        assert critic_output.shape == (batch_size,)

    if rm:
        assert isinstance(rm, RewardModel)
        base_rm_model = get_base_model(rm)
        rm_output = rm(**rm_input)
        assert rm_output.shape == (batch_size,)


@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int, seq_len: int, num_labels: int):
    loss = GPTLMLoss()
    loss_input = {
        "logits": torch.randn(batch_size, seq_len, num_labels),
        "labels": torch.randint(0, num_labels, (batch_size, seq_len))
    }
    loss_output = loss(**loss_input)

    loss = PolicyLoss()
    loss_input = {
        "log_probs": torch.randn(batch_size,),
        "old_log_probs": torch.randn(batch_size,),
        "advantages": torch.randn(batch_size,)
    }
    loss_output = loss(**loss_input)

    loss = ValueLoss()
    loss_input = {
        "values": torch.randn(batch_size,),
        "old_values": torch.randn(batch_size,),
        "reward": torch.randn(batch_size,)
    }
    loss_output = loss(**loss_input)

    loss = LogSigLoss()
    loss_input = {
        "chosen_reward": torch.randn(batch_size,),
        "reject_reward": torch.randn(batch_size,),
    }
    loss_output = loss(**loss_input)

    loss = LogExpLoss()
    loss_input = {
        "chosen_reward": torch.randn(batch_size,),
        "reject_reward": torch.randn(batch_size,),
    }
    loss_output = loss(**loss_input)


if __name__ == "__main__":
    generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
    test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)

    test_utils()

    test_lora(lora_rank=2, num_dim=8, num_layers=2)

    test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)

    test_loss(batch_size=8, seq_len=128, num_labels=100)
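
# Usage note (a minimal sketch, assuming this file is saved as test_models.py,
# pytest is installed, and a CUDA device is available for test_generation):
#
#   pytest -q test_models.py -k "test_lora or test_loss or test_utils"   # CPU-only subset
#   pytest -q test_models.py                                             # full suite
#
# Running the module directly executes the __main__ smoke tests above instead.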