[fx] added testing for all albert variants (#1211)

Branch: pull/1213/head
Author: Frank Lee, 2022-07-06 15:11:08 +08:00, committed by GitHub
Parent: 2d13a45a3b
Commit: 5da87ce35d
7 changed files with 193 additions and 4 deletions

File 1 of 7

@@ -39,7 +39,7 @@ class ColoProxy(Proxy):
             self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'

     def _assert_has_meta_data(self):
-        assert self._meta_data, f'Meta data is not set for {self.node.name}'
+        assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'

     @property
     def device(self):
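
Note on the change above: tensors do not coerce cleanly to bool, so "assert self._meta_data" would raise "Boolean value of Tensor with more than one element is ambiguous" rather than catch an unset value; the "is not None" identity check is the safe form. A minimal illustration in plain PyTorch (written for this note, not part of the commit):

    import torch

    t = torch.empty(2, 2)
    try:
        bool(t)                  # truth-testing a multi-element tensor
    except RuntimeError as e:
        print(e)                 # "Boolean value of Tensor ... is ambiguous"

    assert t is not None         # identity check never touches tensor values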
@@ -63,7 +63,7 @@ class ColoProxy(Proxy):
     def size(self, dim: int = None):
         self._assert_meta_data_is_tensor()
-        if dim:
+        if dim is not None:
             return self.meta_data.size(dim=dim)
         else:
             # size(dim=None) will trigger runtime error for meta tensor
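
The second hunk fixes the same truthiness class of bug: dim is an integer, and "if dim:" treats dim=0 as false, so size(dim=0) would fall through to the size() call that fails on meta tensors. A quick check with an ordinary tensor (illustrative only):

    import torch

    def size_of(t, dim=None):
        if dim is not None:      # "if dim:" would wrongly skip dim=0
            return t.size(dim=dim)
        return t.size()

    t = torch.empty(3, 4)
    print(size_of(t, dim=0))     # 3
    print(size_of(t))            # torch.Size([3, 4])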

File 2 of 7

@@ -1,3 +1,4 @@
+from curses import meta
 import operator
 import torch
 from .registry import meta_patched_function
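
Two remarks on this file. First, "from curses import meta" appears to be a stray editor auto-import: it is never used below, and curses is unavailable on Windows. Second, for readers new to the patching mechanism, meta_patched_function acts as a registry whose register decorator maps a torch callable to a shape-inference stand-in. A minimal sketch of such a registry (illustrative; the real one lives in .registry and may differ):

    import torch

    class PatchRegistry:
        """Maps a torch callable to a meta-device stand-in."""

        def __init__(self):
            self._patches = {}

        def register(self, target):
            def wrapper(fn):
                self._patches[target] = fn
                return fn
            return wrapper

        def get(self, target):
            return self._patches[target]

    meta_patched_function = PatchRegistry()

    @meta_patched_function.register(torch.abs)
    def torch_abs_stub(input, *, out=None):
        return torch.empty(input.shape, device='meta')

    # A tracer can divert calls it cannot execute on meta tensors:
    stub = meta_patched_function.get(torch.abs)
    print(stub(torch.empty(2, 3, device='meta')).shape)    # torch.Size([2, 3])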
@@ -88,4 +89,56 @@ def torch_arange(*args, **kwargs):
 def torch_where(condition, x, y):
     # torch.where returns the broadcasted tensor of condition, x, and y,
     # so hack it by using addition
-    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
\ No newline at end of file
+    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
+
+
+@meta_patched_function.register(torch.abs)
+def torch_abs(input, *, out=None):
+    assert out is None, 'out is not supported yet'
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.relu)
+def torch_nn_func_relu(input, inplace=False):
+    assert not inplace, 'inplace is not supported yet'
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.Tensor.repeat)
+def torch_tensor_repeat(self, *sizes):
+    shape = list(self.shape)
+    for i, x in enumerate(sizes):
+        shape[i] *= x
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.index_select)
+def torch_index_select(input, dim, index, *, out=None):
+    shape = list(input.shape)
+    shape[dim] = len(index)
+    return torch.empty(*shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.index_select)
+def torch_tensor_index_select(self, dim, index):
+    return torch_index_select(self, dim, index)
+
+
+@meta_patched_function.register(torch.nn.functional.embedding)
+def torch_nn_functional_embedding(input,
+                                  weight,
+                                  padding_idx=None,
+                                  max_norm=None,
+                                  norm_type=2.0,
+                                  scale_grad_by_freq=False,
+                                  sparse=False):
+    return torch.empty(*input.shape, weight.shape[-1], device="meta")
+
+
+@meta_patched_function.register(torch.bmm)
+def torch_bmm(input, mat2, *, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
+    batch_size, n, m = input.shape
+    _, _, p = mat2.shape
+    return torch.empty(batch_size, n, p, device="meta")
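
These patches only need to reproduce output shapes on the meta device, never values. A quick sanity check against the real ops (written for this note; bmm and embedding shown):

    import torch

    a, b = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
    assert torch.bmm(a, b).shape == torch_bmm(a.to('meta'), b.to('meta')).shape    # (2, 3, 5)

    ids = torch.zeros(2, 16, dtype=torch.int64)
    weight = torch.randn(30000, 128)
    real = torch.nn.functional.embedding(ids, weight)
    meta = torch_nn_functional_embedding(ids.to('meta'), weight.to('meta'))
    assert real.shape == meta.shape == (2, 16, 128)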

File 3 of 7

@@ -116,3 +116,9 @@ def torch_nn_maxpool3d(self, input):
         w_out,
     )
     return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.ReLU)
+def torch_nn_relu(self, input):
+    assert not self.inplace, 'inplace is not supported yet'
+    return input.clone()
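
Module patches differ from the function patches in the previous file: they receive the module instance first, so attributes such as self.inplace can be inspected. Calling the ReLU patch directly shows the contract (illustrative call; the tracer normally performs this dispatch):

    import torch

    relu = torch.nn.ReLU()
    x = torch.empty(4, 8, device='meta')
    out = torch_nn_relu(relu, x)     # shape propagation only, no real compute
    print(out.shape, out.device)     # torch.Size([4, 8]) meta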

File 4 of 7

@@ -0,0 +1,65 @@
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def test_single_sentence_albert():
+    MODEL_LIST = [
+        transformers.AlbertModel,
+        transformers.AlbertForPreTraining,
+        transformers.AlbertForMaskedLM,
+        transformers.AlbertForSequenceClassification,
+        transformers.AlbertForTokenClassification,
+    ]
+
+    config = transformers.AlbertConfig(embedding_size=128,
+                                       hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return meta_args
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+def test_multi_sentence_albert():
+    config = transformers.AlbertConfig(hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+
+    def data_gen_for_qa():
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        inputs = tokenizer(question, text, return_tensors="pt")
+        return inputs
+
+    model = transformers.AlbertForQuestionAnswering(config)
+    trace_model_and_compare_output(model, data_gen_for_qa)
+
+    def data_gen_for_mcq():
+        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        choice0 = "It is eaten with a fork and a knife."
+        choice1 = "It is eaten while held in the hand."
+        encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+        encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
+        return encoding
+
+    model = transformers.AlbertForMultipleChoice(config)
+    trace_model_and_compare_output(model, data_gen_for_mcq)
+
+
+if __name__ == '__main__':
+    test_single_sentence_albert()
+    test_multi_sentence_albert()
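
The trace_model_and_compare_output helper (last file in this diff) traces each model and asserts that traced outputs match eager outputs. The same idea with vanilla torch.fx on a toy module, as a self-contained reference (the real helper uses the colossalai tracer plus the meta patches above):

    import torch
    from torch.fx import symbolic_trace

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.abs(x) + 1

    model = Toy()
    gm = symbolic_trace(model)               # graph capture step
    x = torch.randn(2, 3)
    assert torch.equal(gm(x), model(x))      # per-output check the tests perform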

File 5 of 7

@@ -0,0 +1,31 @@
+import pytest
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('value is not aligned yet')
+def test_opt():
+    MODEL_LIST = [
+        transformers.OPTModel,
+        transformers.OPTForCausalLM,
+    ]
+
+    config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_opt()

File 6 of 7

@@ -0,0 +1,32 @@
+import pytest
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('value is not aligned yet')
+def test_t5():
+    MODEL_LIST = [
+        transformers.T5Model,
+        transformers.T5ForConditionalGeneration,
+        transformers.T5EncoderModel,
+    ]
+
+    config = transformers.T5Config(d_model=128, num_layers=2)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_t5()

File 7 of 7

@@ -30,4 +30,6 @@ def trace_model_and_compare_output(model, data_gen):
     for k in non_fx_out.keys():
         if torch.is_tensor(fx_out[k]):
-            assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
+            assert torch.equal(
+                fx_out[k], non_fx_out[k]
+            ), f'{model.__class__.__name__} has incorrect output {k}, expected {non_fx_out[k]}, but got {fx_out[k]}'
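
Worth noting: torch.equal demands exact element-wise equality, which is why the OPT and T5 tests above are skipped as "value is not aligned yet"; any numerical drift fails the check, whereas a tolerance-based comparison would pass. For example:

    import torch

    a = torch.tensor([1.0, 2.0])
    b = a + 1e-7
    print(torch.equal(a, b))                  # False: exact match required
    print(torch.allclose(a, b, atol=1e-6))    # True: within tolerance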