diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index b36078f0a..3299b14a4 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -39,7 +39,7 @@ class ColoProxy(Proxy):
             self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
 
     def _assert_has_meta_data(self):
-        assert self._meta_data, f'Meta data is not set for {self.node.name}'
+        assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
 
     @property
     def device(self):
@@ -63,7 +63,7 @@ class ColoProxy(Proxy):
 
     def size(self, dim: int = None):
         self._assert_meta_data_is_tensor()
-        if dim:
+        if dim is not None:
             return self.meta_data.size(dim=dim)
         else:
             # size(dim=None) will trigger runtime error for meta tensor
diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py
index a5f57f6ac..fa1c7fd12 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function.py
@@ -88,4 +88,64 @@ def torch_arange(*args, **kwargs):
 def torch_where(condition, x, y):
     # torch.where returns the broadcasted tensor of condition, x, and y,
     # so hack it by using addition
-    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
\ No newline at end of file
+    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
+
+
+@meta_patched_function.register(torch.abs)
+def torch_abs(input, *, out=None):
+    assert out is None, 'out is not supported yet'
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.nn.functional.relu)
+def torch_nn_func_relu(input, inplace=False):
+    assert not inplace, 'inplace is not supported yet'
+    return torch.empty(input.shape, device='meta')
+
+
+@meta_patched_function.register(torch.Tensor.repeat)
+def torch_tensor_repeat(self, *sizes):
+    shape = list(self.shape)
+    for i, x in enumerate(sizes):
+        shape[i] *= x
+    return torch.empty(shape, device="meta")
+
+
+@meta_patched_function.register(torch.index_select)
+def torch_index_select(input, dim, index, *, out=None):
+    shape = list(input.shape)
+    shape[dim] = len(index)
+    return torch.empty(*shape, device="meta")
+
+
+@meta_patched_function.register(torch.Tensor.index_select)
+def torch_tensor_index_select(self, dim, index):
+    return torch_index_select(self, dim, index)
+
+
+@meta_patched_function.register(torch.nn.functional.embedding)
+def torch_nn_functional_embedding(input,
+                                  weight,
+                                  padding_idx=None,
+                                  max_norm=None,
+                                  norm_type=2.0,
+                                  scale_grad_by_freq=False,
+                                  sparse=False):
+    return torch.empty(*input.shape, weight.shape[-1], device="meta")
+
+
+@meta_patched_function.register(torch.bmm)
+def torch_bmm(input, mat2, *, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
+    batch_size, n, m = input.shape
+    _, _, p = mat2.shape
+    return torch.empty(batch_size, n, p, device="meta")
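+
+
+# Editorial sketch (not exercised by the tests below): each patched function
+# above only reproduces the output *shape* on the meta device. For torch.bmm
+# with inputs of shape (B, n, m) and (B, m, p), the result is (B, n, p):
+#   a = torch.empty(2, 3, 4, device='meta')
+#   b = torch.empty(2, 4, 5, device='meta')
+#   torch_bmm(a, b).shape  # torch.Size([2, 3, 5])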
diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py
index bf6bc33da..61fc341be 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module.py
@@ -116,3 +116,16 @@ def torch_nn_maxpool3d(self, input):
         w_out,
     )
     return torch.empty(result_shape, device='meta')
+
+
+@meta_patched_module.register(torch.nn.ReLU)
+def torch_nn_relu(self, input):
+    assert not self.inplace, 'inplace is not supported yet'
+    return input.clone()
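+
+
+# Hedged usage sketch, assuming the registry's convention of passing the
+# module itself as the first argument (as in the patches above): a meta
+# tensor passes through without running the real forward, e.g.
+#   torch_nn_relu(torch.nn.ReLU(), torch.empty(2, 4, device='meta')).shape
+#   # -> torch.Size([2, 4])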
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
new file mode 100644
index 000000000..e28d20dc3
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
@@ -0,0 +1,65 @@
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def test_single_sentence_albert():
+    MODEL_LIST = [
+        transformers.AlbertModel,
+        transformers.AlbertForPreTraining,
+        transformers.AlbertForMaskedLM,
+        transformers.AlbertForSequenceClassification,
+        transformers.AlbertForTokenClassification,
+    ]
+
+    config = transformers.AlbertConfig(embedding_size=128,
+                                       hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return meta_args
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+def test_multi_sentence_albert():
+    config = transformers.AlbertConfig(hidden_size=128,
+                                       num_hidden_layers=2,
+                                       num_attention_heads=4,
+                                       intermediate_size=256)
+    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+
+    def data_gen_for_qa():
+        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+        inputs = tokenizer(question, text, return_tensors="pt")
+        return inputs
+
+    model = transformers.AlbertForQuestionAnswering(config)
+    trace_model_and_compare_output(model, data_gen_for_qa)
+
+    def data_gen_for_mcq():
+        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        choice0 = "It is eaten with a fork and a knife."
+        choice1 = "It is eaten while held in the hand."
+        encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+        encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
+        return encoding
+
+    model = transformers.AlbertForMultipleChoice(config)
+    trace_model_and_compare_output(model, data_gen_for_mcq)
+
+
+if __name__ == '__main__':
+    test_single_sentence_albert()
+    test_multi_sentence_albert()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
new file mode 100644
index 000000000..78d17386f
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -0,0 +1,31 @@
+import pytest
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('values are not aligned yet')
+def test_opt():
+    MODEL_LIST = [
+        transformers.OPTModel,
+        transformers.OPTForCausalLM,
+    ]
+
+    config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_opt()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
new file mode 100644
index 000000000..001ada2db
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -0,0 +1,32 @@
+import pytest
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 16
+
+
+@pytest.mark.skip('values are not aligned yet')
+def test_t5():
+    MODEL_LIST = [
+        transformers.T5Model,
+        transformers.T5ForConditionalGeneration,
+        transformers.T5EncoderModel,
+    ]
+
+    config = transformers.T5Config(d_model=128, num_layers=2)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_t5()
diff --git a/tests/test_fx/test_tracer/test_hf_model/utils.py b/tests/test_fx/test_tracer/test_hf_model/utils.py
index cd5e54039..382c87ad5 100644
--- a/tests/test_fx/test_tracer/test_hf_model/utils.py
+++ b/tests/test_fx/test_tracer/test_hf_model/utils.py
@@ -30,4 +30,9 @@ def trace_model_and_compare_output(model, data_gen):
 
     for k in non_fx_out.keys():
         if torch.is_tensor(fx_out[k]):
-            assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
+            assert torch.equal(
+                fx_out[k], non_fx_out[k]
+            ), f'{model.__class__.__name__} has incorrect output {k}, expected {non_fx_out[k]}, but got {fx_out[k]}'
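+            # Editorial note: torch.equal demands identical shape and values,
+            # so any divergence the traced graph introduces relative to the
+            # eager run is caught here when concrete outputs are compared.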