2022-05-19 10:57:56 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from .registry import non_distributed_component_funcs
|
|
|
|
from transformers import GPT2Config, GPT2LMHeadModel
|
|
|
|
from .utils.dummy_data_generator import DummyDataGenerator
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
|
|
|
|
|
|
|
|
class DummyDataLoader(DummyDataGenerator):
    """Synthetic GPT-2 batches: random token ids plus an all-ones attention mask.

    Batch geometry comes from the class attributes below, which ``generate``
    reads through the class name (not through ``self``).
    """

    # Token ids are drawn from [0, vocab_size); 128 / 64 match the
    # vocab_size / max_seq_len used by gpt2_micro in this file.
    vocab_size = 128
    batch_size = 4
    seq_len = 64

    def generate(self):
        """Return ``(input_ids, attention_mask)``.

        ``input_ids`` is a ``(batch_size, seq_len)`` tensor of ids sampled by
        ``torch.randint`` from ``[0, vocab_size)`` and placed on the device
        returned by ``get_current_device()``. ``attention_mask`` has the same
        shape, dtype and device and is all ones, so no position is masked.
        """
        shape = (DummyDataLoader.batch_size, DummyDataLoader.seq_len)
        token_ids = torch.randint(0, DummyDataLoader.vocab_size, shape, device=get_current_device())
        mask = torch.ones_like(token_ids)
        return token_ids, mask
|
|
|
|
|
|
|
|
|
|
|
|
class GPTLMModel(nn.Module):
    """Wrapper around HuggingFace ``GPT2LMHeadModel`` whose forward returns only logits.

    All dropout probabilities (residual, embedding, attention) are fixed at 0.0.

    Args:
        hidden_size: hidden / embedding width (``n_embd``).
        num_layers: number of transformer blocks (``n_layer``).
        num_attention_heads: attention heads per block (``n_head``).
        max_seq_len: used for both ``n_positions`` and ``n_ctx``.
        vocab_size: size of the token vocabulary.
        checkpoint: when truthy, gradient checkpointing is enabled on the
            wrapped model and ``forward`` passes ``use_cache=False``.
    """

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50304,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        gpt2_config = GPT2Config(
            n_embd=hidden_size,
            n_layer=num_layers,
            n_head=num_attention_heads,
            n_positions=max_seq_len,
            n_ctx=max_seq_len,
            vocab_size=vocab_size,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )
        self.model = GPT2LMHeadModel(gpt2_config)
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return element 0 of its output (the LM logits)."""
        # The KV cache is disabled whenever checkpointing is on; presumably
        # because HF does not support caching together with gradient
        # checkpointing (TODO confirm).
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)
        return outputs[0]
|
|
|
|
|
2022-09-08 08:34:09 +00:00
|
|
|
|
2022-06-08 15:14:18 +00:00
|
|
|
def gpt2_micro(checkpoint=True):
    """Build a tiny 2-layer GPT-2 (hidden 32, 4 heads, 64 positions, vocab 128)."""
    micro_dims = dict(
        hidden_size=32,
        num_layers=2,
        num_attention_heads=4,
        max_seq_len=64,
        vocab_size=128,
    )
    return GPTLMModel(checkpoint=checkpoint, **micro_dims)
|
|
|
|
|
2022-05-19 10:57:56 +00:00
|
|
|
|
|
|
|
def gpt2_s(checkpoint=True):
    """Build a GPT-2 using every ``GPTLMModel`` default size (small configuration)."""
    model = GPTLMModel(checkpoint=checkpoint)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
def gpt2_m(checkpoint=True):
    """Build a 24-layer GPT-2 with hidden size 1024 and 16 heads (medium configuration)."""
    medium_dims = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **medium_dims)
|
|
|
|
|
|
|
|
|
|
|
|
class GPTLMLoss(nn.Module):
    """Next-token cross-entropy: logits at position ``t`` are scored against label ``t + 1``.

    The sequence axis is the second-to-last axis of ``logits`` and the last
    axis of ``labels``; everything before it is flattened together with it.
    The reduction is ``nn.CrossEntropyLoss``'s default (mean).
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the shifted cross-entropy between ``logits`` and ``labels``.

        Args:
            logits: scores whose last axis is the vocabulary, e.g.
                ``(batch, seq, vocab)`` (shape assumed from indexing).
            labels: target token ids whose last axis is the sequence.
        """
        num_classes = logits.size(-1)
        # The final logit has no next token to predict and the first label is
        # predicted by nothing, so drop both to align t with t + 1.
        predictions = logits[..., :-1, :].reshape(-1, num_classes)
        targets = labels[..., 1:].reshape(-1)
        return self.loss_fn(predictions, targets)
|
|
|
|
|
|
|
|
|
|
|
|
@non_distributed_component_funcs.register(name='gpt2')
def get_training_components():
    """Return the components registered under the name ``'gpt2'``.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)``: the ``gpt2_micro`` builder function (not an instance),
        two independent ``DummyDataLoader`` instances, the
        ``torch.optim.Adam`` class, and a ``GPTLMLoss`` instance.
    """
    train_loader, test_loader = DummyDataLoader(), DummyDataLoader()
    loss = GPTLMLoss()
    return gpt2_micro, train_loader, test_loader, torch.optim.Adam, loss
|