From 426a279ce7134e3c40c02dfed3d65dce56cf908b Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 6 Jul 2022 10:50:49 +0800 Subject: [PATCH] [fx] added testing for all bert variants (#1207) * [fx] added testing for all bert variants * polish code --- colossalai/fx/proxy.py | 18 +++- .../test_tracer/test_hf_model/test_hf_bert.py | 90 +++++++++++++++---- 2 files changed, 88 insertions(+), 20 deletions(-) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index 50a004a12..2f8ca6d94 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -1,6 +1,8 @@ import operator import torch from torch.fx.proxy import Proxy, Attribute +from typing import List, Union +from torch.utils._pytree import PyTree __all__ = ['ColoProxy'] @@ -26,8 +28,12 @@ class ColoProxy(Proxy): return self._meta_tensor @meta_tensor.setter - def meta_tensor(self, tensor: torch.Tensor): - assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor' + def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]): + + def _is_meta(item): + assert torch.is_tensor(item) and item.is_meta + + torch.fx.node.map_aggregate(tensor, _is_meta) self._meta_tensor = tensor @property @@ -83,6 +89,14 @@ class ColoProxy(Proxy): def __setitem__(self, indices, values): return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {}) + def __contains__(self, key): + if self.node.op == "placeholder": + # this is used to handle like + # if x in kwargs + # we don't handle this case for now + return False + return super().__contains__(key) + class ColoAttribute(ColoProxy): diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 303d7d5f3..b9a855cef 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -7,36 +7,90 @@ BATCH_SIZE = 2 SEQ_LENGHT = 16 -def test_bert(): +def trace_bert_and_compare_output(model, data_gen): tracer = ColoTracer() - config = transformers.BertConfig() - model = transformers.BertModel(config=config) - - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta') - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta') - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta') - meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - # make sure that the model is traceable - graph = tracer.trace(root=model, meta_args=meta_args) + try: + kwargs = data_gen() + meta_args = {k: v.to('meta') for k, v in kwargs.items()} + graph = tracer.trace(root=model, meta_args=meta_args) + except Exception as e: + raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() # check output - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + inputs = data_gen() # must turn on eval mode to ensure the output is consistent gm.eval() model.eval() # run forward - fx_out = gm(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - non_fx_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - assert fx_out['last_hidden_state'].shape == non_fx_out['last_hidden_state'].shape - assert torch.equal(fx_out['last_hidden_state'], non_fx_out['last_hidden_state']) + non_fx_out = model(**inputs) + fx_out = gm(**inputs) + + for k in non_fx_out.keys(): + assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}' + + +def test_single_sentence_bert(): + MODEL_LIST = [ + transformers.BertModel, + transformers.BertForPreTraining, + transformers.BertLMHeadModel, + transformers.BertForMaskedLM, + transformers.BertForSequenceClassification, + transformers.BertForTokenClassification, + ] + + config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) + + def data_gen(): + input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) + meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + return meta_args + + for model_cls in MODEL_LIST: + model = model_cls(config=config) + trace_bert_and_compare_output(model, data_gen) + + +def test_multi_sentence_bert(): + config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) + tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + + def data_gen_for_next_sentence(): + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + next_sentence = "The sky is blue due to the shorter wavelength of blue light." + encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + return encoding + + model = transformers.BertForNextSentencePrediction(config) + trace_bert_and_compare_output(model, data_gen_for_next_sentence) + + def data_gen_for_qa(): + question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + inputs = tokenizer(question, text, return_tensors="pt") + return inputs + + model = transformers.BertForQuestionAnswering(config) + trace_bert_and_compare_output(model, data_gen_for_qa) + + def data_gen_for_mcq(): + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." + encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} + return encoding + + model = transformers.BertForMultipleChoice(config) + trace_bert_and_compare_output(model, data_gen_for_mcq) if __name__ == '__main__': - test_bert() + test_single_sentence_bert() + test_multi_sentence_bert()