import torch
import torch.nn as nn
from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel

from tests.components_to_test.registry import non_distributed_component_funcs


class GPTLMModel(nn.Module):
    """Wraps a GPT-2 LM head model and returns only the language-model logits."""

    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
        super().__init__()
        self.model = GPT2LMHeadModel(
            GPT2Config(
                n_embd=hidden_size,
                n_layer=num_layers,
                n_head=num_attention_heads,
                n_positions=max_seq_len,
                n_ctx=max_seq_len,
                vocab_size=vocab_size,
            )
        )

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits (the first element of the model output)
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]


class LMLoss(nn.Module):
    """Standard causal language-modeling loss: each position predicts the next token."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Shift so that the logits at position t are scored against the token at position t + 1
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens before computing cross-entropy
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
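
# Illustrative note (not part of the original component file): for labels
# [t0, t1, t2, t3], LMLoss compares the logits at positions 0..2 with the targets
# [t1, t2, t3]; the final position has no next token and is therefore dropped.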


class BertLMModel(nn.Module):
    """Wraps a BERT LM head model and returns only the language-model logits."""

    def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522):
        super().__init__()
        self.model = BertLMHeadModel(
            BertConfig(
                num_hidden_layers=num_layers,
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                max_position_embeddings=hidden_size,
                vocab_size=vocab_size,
            )
        )

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits (the first element of the model output)
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]


@non_distributed_component_funcs.register(name="bert_")
def get_bert_components():
    vocab_size = 1024
    seq_len = 64
    batch_size = 64

    def bert_model_builder():
        model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size)
        return model

    def bert_data_gen(device="meta"):
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        attention_mask = torch.ones_like(input_ids, device=device)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    return bert_model_builder, bert_data_gen


@non_distributed_component_funcs.register(name="gpt2_")
def get_gpt2_components():
    vocab_size = 1024
    seq_len = 8
    batch_size = 64

    def gpt2_model_builder():
        model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size)
        return model

    def gpt2_data_gen(device="meta"):
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        attention_mask = torch.ones_like(input_ids, device=device)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    return gpt2_model_builder, gpt2_data_gen
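

# Optional smoke-test sketch, not part of the original components: it builds a
# deliberately tiny GPTLMModel (the sizes below are illustrative assumptions) and
# checks that LMLoss returns a scalar next-token-prediction loss.
if __name__ == "__main__":
    model = GPTLMModel(hidden_size=64, num_layers=2, num_attention_heads=4, max_seq_len=32, vocab_size=128)
    criterion = LMLoss()
    input_ids = torch.randint(0, 128, (2, 16))  # (batch_size, seq_len)
    attention_mask = torch.ones_like(input_ids)
    logits = model(input_ids, attention_mask)  # (batch_size, seq_len, vocab_size)
    loss = criterion(logits, input_ids)  # labels are the input tokens themselves
    print(f"smoke-test loss: {loss.item():.4f}")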