@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel
 
-from tests.components_to_test.registry import non_distributed_component_funcs
+# from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 class GPTLMModel(nn.Module):
@@ -55,7 +55,7 @@ class BertLMModel(nn.Module):
         return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
 
 
-@non_distributed_component_funcs.register(name="bert_")
+# @non_distributed_component_funcs.register(name="bert_")
 def get_bert_components():
     vocab_size = 1024
     seq_len = 64
@@ -74,7 +74,7 @@ def get_bert_components():
     return bert_model_builder, bert_data_gen
 
 
-@non_distributed_component_funcs.register(name="gpt2_")
+# @non_distributed_component_funcs.register(name="gpt2_")
 def get_gpt2_components():
     vocab_size = 1024
     seq_len = 8
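For context, the forward methods in this file index `[0]` of the Hugging Face model output to get the logits. A minimal standalone sketch of that same pattern, using the sizes from the `gpt2_` hunk (`vocab_size=1024`, `seq_len=8`); the batch size and the remaining config fields are assumptions, not taken from the diff:

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# vocab_size/n_positions come from the gpt2_ hunk above; n_layer/n_head/n_embd
# are kept small so the sketch runs quickly (assumed values, not from the diff).
config = GPT2Config(vocab_size=1024, n_positions=8, n_layer=2, n_head=2, n_embd=64)
model = GPT2LMHeadModel(config)

input_ids = torch.randint(0, 1024, (2, 8))   # (batch, seq_len), assumed batch of 2
attention_mask = torch.ones_like(input_ids)

# Same access pattern as the diff: with no labels, element [0] of the output
# is the logits tensor.
logits = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
print(logits.shape)  # torch.Size([2, 8, 1024]) -> (batch, seq_len, vocab_size)
```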