mirror of https://github.com/hpcaitech/ColossalAI
[fx] added testing for all albert variants (#1211)
parent 2d13a45a3b
commit 5da87ce35d
@@ -39,7 +39,7 @@ class ColoProxy(Proxy):
             self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
 
     def _assert_has_meta_data(self):
-        assert self._meta_data, f'Meta data is not set for {self.node.name}'
+        assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
 
     @property
     def device(self):
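Why the `is not None` form in `_assert_has_meta_data` matters: truthiness of a tensor is not a presence check. `bool()` on a multi-element tensor raises a RuntimeError, and a meta tensor has no data to evaluate for truthiness at all, so the old assert could crash even when meta data was set. A minimal sketch in plain PyTorch:

import torch

t = torch.zeros(2, 3, device='meta')
# bool(t) would raise: the boolean value of a multi-element tensor is
# ambiguous, and a meta tensor carries no data to evaluate anyway.
# The explicit None check only asks whether the attribute was populated:
assert t is not None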
@@ -63,7 +63,7 @@ class ColoProxy(Proxy):
 
     def size(self, dim: int = None):
         self._assert_meta_data_is_tensor()
-        if dim:
+        if dim is not None:
             return self.meta_data.size(dim=dim)
         else:
             # size(dim=None) will trigger runtime error for meta tensor
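The `if dim:` to `if dim is not None:` change fixes a classic falsy-zero bug: `size(0)` would have taken the no-dim branch because `bool(0)` is False. A standalone illustration:

import torch

x = torch.empty(4, 5, device='meta')
dim = 0
assert not bool(dim)         # 0 is falsy, so `if dim:` would skip this branch
if dim is not None:          # correct test: only skip when dim was omitted
    assert x.size(dim) == 4  # shape metadata is available even on meta tensors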
@@ -1,3 +1,4 @@
+from curses import meta
 import operator
 import torch
 from .registry import meta_patched_function
@@ -88,4 +89,56 @@ def torch_arange(*args, **kwargs):
 def torch_where(condition, x, y):
     # torch.where returns the broadcasted tensor of condition, x, and y,
     # so hack it by using addition
-    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
+    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
+
+
+@meta_patched_function.register(torch.abs)
+def torch_abs(input, *, out=None):
+    assert out is None, 'out is not supported yet'
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.relu)
+def torch_nn_func_relu(input, inplace=False):
+    assert not inplace, 'inplace is not supported yet'
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.Tensor.repeat)
+def torch_tensor_repeat(self, *sizes):
+    shape = list(self.shape)
+    for i, x in enumerate(sizes):
+        shape[i] *= x
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.index_select)
+def torch_index_select(input, dim, index, *, out=None):
+    shape = list(input.shape)
+    shape[dim] = len(index)
+    return torch.empty(*shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.index_select)
+def torch_tensor_index_select(self, dim, index):
+    return torch_index_select(self, dim, index)
+
+
+@meta_patched_function.register(torch.nn.functional.embedding)
+def torch_nn_functional_embedding(input,
+                                  weight,
+                                  padding_idx=None,
+                                  max_norm=None,
+                                  norm_type=2.0,
+                                  scale_grad_by_freq=False,
+                                  sparse=False):
+    return torch.empty(*input.shape, weight.shape[-1], device="meta")
+
+
+@meta_patched_function.register(torch.bmm)
+def torch_bmm(input, mat2, *, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
+    batch_size, n, m = input.shape
+    _, _, p = mat2.shape
+    return torch.empty(batch_size, n, p, device="meta")
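All of the stand-ins above follow one pattern: derive the output shape from the input shapes and return an empty tensor on the meta device, so tracing never touches real data. A self-contained check of the `torch_bmm` shape logic (the function is restated here so the snippet runs on its own):

import torch

def torch_bmm(input, mat2, *, out=None):
    # shape rule of torch.bmm: (B, n, m) x (B, m, p) -> (B, n, p)
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")

a = torch.empty(2, 3, 4, device="meta")
b = torch.empty(2, 4, 5, device="meta")
assert torch_bmm(a, b).shape == torch.Size([2, 3, 5])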
@@ -116,3 +116,9 @@ def torch_nn_maxpool3d(self, input):
         w_out,
     )
     return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.ReLU)
+def torch_nn_func_relu(self, input):
+    assert not self.inplace, 'inplace is not supported yet'
+    return input.clone()
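The module-level ReLU patch can simply clone its input because ReLU is elementwise: the output has exactly the input's shape and dtype, and `clone()` on a meta tensor allocates no real storage. A quick illustration:

import torch

x = torch.empty(8, 128, device='meta')
y = x.clone()                 # still a meta tensor: shape/dtype preserved, no data
assert y.is_meta and y.shape == x.shape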
@@ -0,0 +1,65 @@
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def test_single_sentence_albert():
+    MODEL_LIST = [
+        transformers.AlbertModel,
+        transformers.AlbertForPreTraining,
+        transformers.AlbertForMaskedLM,
+        transformers.AlbertForSequenceClassification,
+        transformers.AlbertForTokenClassification,
+    ]
+
+    config = transformers.AlbertConfig(embedding_size=128,
+                                       hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return meta_args
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+def test_multi_sentence_albert():
+    config = transformers.AlbertConfig(hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+
+    def data_gen_for_qa():
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        inputs = tokenizer(question, text, return_tensors="pt")
+        return inputs
+
+    model = transformers.AlbertForQuestionAnswering(config)
+    trace_model_and_compare_output(model, data_gen_for_qa)
+
+    def data_gen_for_mcq():
+        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        choice0 = "It is eaten with a fork and a knife."
+        choice1 = "It is eaten while held in the hand."
+        encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+        encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
+        return encoding
+
+    model = transformers.AlbertForMultipleChoice(config)
+    trace_model_and_compare_output(model, data_gen_for_mcq)
+
+
+if __name__ == '__main__':
+    test_single_sentence_albert()
+    test_multi_sentence_albert()
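In `data_gen_for_mcq`, the `unsqueeze(0)` supplies the batch dimension that multiple-choice heads expect, turning (num_choices, seq_len) token tensors into (batch, num_choices, seq_len). Illustrated with a hypothetical tokenized length of 12:

import torch

input_ids = torch.zeros((2, 12), dtype=torch.int64)   # (choices, seq)
batched = input_ids.unsqueeze(0)                      # (1, 2, 12): (batch, choices, seq)
assert batched.shape == torch.Size([1, 2, 12])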
@@ -0,0 +1,31 @@
+import pytest
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('value is not aligned yet')
+def test_opt():
+    MODEL_LIST = [
+        transformers.OPTModel,
+        transformers.OPTForCausalLM,
+    ]
+
+    config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_opt()
@@ -0,0 +1,32 @@
+import pytest
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('value is not aligned yet')
+def test_t5():
+    MODEL_LIST = [
+        transformers.T5Model,
+        transformers.T5ForConditionalGeneration,
+        transformers.T5EncoderModel,
+    ]
+
+    config = transformers.T5Config(d_model=128, num_layers=2)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_t5()
@@ -30,4 +30,6 @@ def trace_model_and_compare_output(model, data_gen):
 
     for k in non_fx_out.keys():
         if torch.is_tensor(fx_out[k]):
-            assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
+            assert torch.equal(
+                fx_out[k], non_fx_out[k]
+            ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}'
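Only the comparison tail of `trace_model_and_compare_output` appears in this diff. A hedged reconstruction of the rest, assuming the usual ColossalAI flow of tracing with `ColoTracer` against meta-tensor inputs and then comparing traced output with eager output on real data; the body below is a sketch, not the repository's actual utils.py:

import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer    # assumed import path

def trace_model_and_compare_output(model, data_gen):
    model.eval()
    inputs = data_gen()
    # trace against meta-tensor stand-ins so no real compute happens
    meta_args = {k: v.to('meta') for k, v in inputs.items()}
    graph = ColoTracer().trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    with torch.no_grad():
        fx_out = gm(**inputs)
        non_fx_out = model(**inputs)

    # this loop is the part shown in the hunk above
    for k in non_fx_out.keys():
        if torch.is_tensor(fx_out[k]):
            assert torch.equal(
                fx_out[k], non_fx_out[k]
            ), f'{model.__class__.__name__} has incorrect output {k}'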