mirror of https://github.com/hpcaitech/ColossalAI
[fx] added testing for all bert variants (#1207)
* [fx] added testing for all bert variants * polish codepull/1208/head
parent
b5f25eb32a
commit
426a279ce7
|
@ -1,6 +1,8 @@
|
|||
import operator
|
||||
import torch
|
||||
from torch.fx.proxy import Proxy, Attribute
|
||||
from typing import List, Union
|
||||
from torch.utils._pytree import PyTree
|
||||
|
||||
__all__ = ['ColoProxy']
|
||||
|
||||
|
@ -26,8 +28,12 @@ class ColoProxy(Proxy):
|
|||
return self._meta_tensor
|
||||
|
||||
@meta_tensor.setter
|
||||
def meta_tensor(self, tensor: torch.Tensor):
|
||||
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||
def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]):
|
||||
|
||||
def _is_meta(item):
|
||||
assert torch.is_tensor(item) and item.is_meta
|
||||
|
||||
torch.fx.node.map_aggregate(tensor, _is_meta)
|
||||
self._meta_tensor = tensor
|
||||
|
||||
@property
|
||||
|
@ -83,6 +89,14 @@ class ColoProxy(Proxy):
|
|||
def __setitem__(self, indices, values):
|
||||
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})
|
||||
|
||||
def __contains__(self, key):
|
||||
if self.node.op == "placeholder":
|
||||
# this is used to handle like
|
||||
# if x in kwargs
|
||||
# we don't handle this case for now
|
||||
return False
|
||||
return super().__contains__(key)
|
||||
|
||||
|
||||
class ColoAttribute(ColoProxy):
|
||||
|
||||
|
|
|
@ -7,36 +7,90 @@ BATCH_SIZE = 2
|
|||
SEQ_LENGHT = 16
|
||||
|
||||
|
||||
def test_bert():
|
||||
def trace_bert_and_compare_output(model, data_gen):
|
||||
tracer = ColoTracer()
|
||||
config = transformers.BertConfig()
|
||||
model = transformers.BertModel(config=config)
|
||||
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
|
||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
|
||||
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
|
||||
# make sure that the model is traceable
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
try:
|
||||
kwargs = data_gen()
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
|
||||
# check output
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
inputs = data_gen()
|
||||
|
||||
# must turn on eval mode to ensure the output is consistent
|
||||
gm.eval()
|
||||
model.eval()
|
||||
|
||||
# run forward
|
||||
fx_out = gm(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
non_fx_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
assert fx_out['last_hidden_state'].shape == non_fx_out['last_hidden_state'].shape
|
||||
assert torch.equal(fx_out['last_hidden_state'], non_fx_out['last_hidden_state'])
|
||||
non_fx_out = model(**inputs)
|
||||
fx_out = gm(**inputs)
|
||||
|
||||
for k in non_fx_out.keys():
|
||||
assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
|
||||
|
||||
|
||||
def test_single_sentence_bert():
|
||||
MODEL_LIST = [
|
||||
transformers.BertModel,
|
||||
transformers.BertForPreTraining,
|
||||
transformers.BertLMHeadModel,
|
||||
transformers.BertForMaskedLM,
|
||||
transformers.BertForSequenceClassification,
|
||||
transformers.BertForTokenClassification,
|
||||
]
|
||||
|
||||
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
|
||||
|
||||
def data_gen():
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
|
||||
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
return meta_args
|
||||
|
||||
for model_cls in MODEL_LIST:
|
||||
model = model_cls(config=config)
|
||||
trace_bert_and_compare_output(model, data_gen)
|
||||
|
||||
|
||||
def test_multi_sentence_bert():
|
||||
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
|
||||
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
def data_gen_for_next_sentence():
|
||||
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
||||
return encoding
|
||||
|
||||
model = transformers.BertForNextSentencePrediction(config)
|
||||
trace_bert_and_compare_output(model, data_gen_for_next_sentence)
|
||||
|
||||
def data_gen_for_qa():
|
||||
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
inputs = tokenizer(question, text, return_tensors="pt")
|
||||
return inputs
|
||||
|
||||
model = transformers.BertForQuestionAnswering(config)
|
||||
trace_bert_and_compare_output(model, data_gen_for_qa)
|
||||
|
||||
def data_gen_for_mcq():
|
||||
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
choice0 = "It is eaten with a fork and a knife."
|
||||
choice1 = "It is eaten while held in the hand."
|
||||
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
|
||||
encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
|
||||
return encoding
|
||||
|
||||
model = transformers.BertForMultipleChoice(config)
|
||||
trace_bert_and_compare_output(model, data_gen_for_mcq)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_bert()
|
||||
test_single_sentence_bert()
|
||||
test_multi_sentence_bert()
|
||||
|
|
Loading…
Reference in New Issue