# Mirror of https://github.com/hpcaitech/ColossalAI
# File viewer metadata: 80 lines, 2.5 KiB, Python
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import GPT2Config, GPT2LMHeadModel
|
|
|
|
from tests.test_elixir.utils.registry import TEST_MODELS
|
|
|
|
# Sizes of the random batch produced by ``micro_data_fn``.
MICRO_VS = 128    # exclusive upper bound for sampled token ids
MICRO_BS = 4    # rows in the batch
MICRO_SL = 64    # tokens per row

# Sizes of the random batch produced by ``small_data_fn``.
MACRO_VS = 50257    # exclusive upper bound for sampled token ids
MACRO_BS = 2    # rows in the batch
MACRO_SL = 1024    # tokens per row
|
|
|
|
|
|
def micro_data_fn():
    """Create a random batch sized by the ``MICRO_*`` constants.

    Returns:
        dict: ``input_ids`` holds token ids sampled uniformly from
        ``[0, MICRO_VS)`` with shape ``(MICRO_BS, MICRO_SL)``;
        ``attention_mask`` is a tensor of ones shaped like ``input_ids``.
    """
    batch_shape = (MICRO_BS, MICRO_SL)
    token_ids = torch.randint(low=0, high=MICRO_VS, size=batch_shape)
    # Every position is marked as attended: the mask is all ones.
    return {
        'input_ids': token_ids,
        'attention_mask': torch.ones_like(token_ids),
    }
|
|
|
|
|
|
def small_data_fn():
    """Create a random batch sized by the ``MACRO_*`` constants.

    Returns:
        dict: ``input_ids`` holds token ids sampled uniformly from
        ``[0, MACRO_VS)`` with shape ``(MACRO_BS, MACRO_SL)``;
        ``attention_mask`` is a tensor of ones shaped like ``input_ids``.
    """
    batch_shape = (MACRO_BS, MACRO_SL)
    token_ids = torch.randint(low=0, high=MACRO_VS, size=batch_shape)
    # Every position is marked as attended: the mask is all ones.
    return {
        'input_ids': token_ids,
        'attention_mask': torch.ones_like(token_ids),
    }
|
|
|
|
|
|
class GPTLMLoss(nn.Module):
    """Cross-entropy loss between each position's logits and the next label."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Compute the shifted cross-entropy loss.

        Args:
            logits: scores whose last dimension is the class dimension and
                whose second-to-last dimension is the sequence position.
            labels: target ids whose last dimension is the sequence position.

        Returns:
            The value produced by ``nn.CrossEntropyLoss`` on the shifted,
            flattened inputs.
        """
        # Logits at position t are scored against the label at t + 1, so drop
        # the last logit position and the first label position.
        predictions = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()

        # Collapse all leading dimensions: (N, num_classes) vs. (N,).
        num_classes = predictions.size(-1)
        flat_predictions = predictions.view(-1, num_classes)
        flat_targets = targets.view(-1)
        return self.loss_fn(flat_predictions, flat_targets)
|
|
|
|
|
|
class GPTLMModel(nn.Module):
    """GPT-2 language model whose forward pass returns the training loss.

    Holds a ``GPT2LMHeadModel`` built from a ``GPT2Config`` with all three
    dropout probabilities set to 0.0, plus a ``GPTLMLoss`` criterion.
    """

    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
        """Build the config, the wrapped GPT-2 module and the loss criterion.

        Args:
            hidden_size: passed to ``GPT2Config`` as ``n_embd``.
            num_layers: passed to ``GPT2Config`` as ``n_layer``.
            num_attention_heads: passed to ``GPT2Config`` as ``n_head``.
            max_seq_len: passed to ``GPT2Config`` as both ``n_positions`` and ``n_ctx``.
            vocab_size: passed to ``GPT2Config`` as ``vocab_size``.
        """
        super().__init__()
        # Set to True by gradient_checkpointing_enable(); forward() reads it to pick use_cache.
        self.enable_gc = False

        config_kwargs = dict(
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            vocab_size=vocab_size,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )
        self.config = GPT2Config(**config_kwargs)
        self.module = GPT2LMHeadModel(config=self.config)
        self.criterion = GPTLMLoss()

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing on the wrapped module and record it."""
        self.module.gradient_checkpointing_enable()
        self.enable_gc = True

    def forward(self, input_ids, attention_mask):
        """Run the wrapped module and return the loss against ``input_ids``.

        Args:
            input_ids: token ids; also used as the labels for the criterion.
            attention_mask: forwarded unchanged to the wrapped module.

        Returns:
            The loss computed by ``self.criterion``.
        """
        # use_cache is switched off once gradient checkpointing has been enabled.
        keep_cache = not self.enable_gc
        outputs = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=keep_cache)
        # Only the first output (the LM logits) is needed for the loss.
        lm_logits = outputs[0]
        return self.criterion(lm_logits, input_ids)
|
|
|
|
|
|
# Each builder is registered under its own name together with the function
# that generates its input batch.

# Tiny configuration, paired with the micro-sized batch.
gpt2_micro = partial(
    GPTLMModel,
    hidden_size=32,
    num_layers=2,
    num_attention_heads=4,
    max_seq_len=64,
    vocab_size=128,
)
TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn)

# The class defaults are used as-is.
gpt2_small = GPTLMModel
TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn)

# Wider and deeper configuration; the remaining arguments keep the class defaults.
gpt2_base = partial(
    GPTLMModel,
    hidden_size=1024,
    num_layers=24,
    num_attention_heads=16,
)
TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn)
|