From f7878f465c92215cb40cf1b6b095c0a1 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Tue, 5 Jul 2022 13:19:57 +0800
Subject: [PATCH] [fx] supported model tracing for huggingface bert (#1201)

* [fx] supported model tracing for huggingface bert

* polish test
---
 colossalai/fx/proxy.py                             |  6 +-
 .../fx/tracer/meta_patch/patched_function.py       | 62 +++++++++++++++++++
 .../fx/tracer/meta_patch/patched_module.py         |  2 +-
 colossalai/fx/tracer/tracer.py                     | 18 +++-
 .../test_tracer/test_hf_model/test_hf_bert.py      | 42 +++++++++++++
 5 files changed, 126 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py

diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index 72f9e646c..50a004a12 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -59,7 +59,11 @@ class ColoProxy(Proxy):
 
     def size(self, dim: int = None):
         self._assert_has_meta()
-        return self.meta_tensor.size(dim=dim)
+        if dim is not None:
+            return self.meta_tensor.size(dim=dim)
+        else:
+            # calling size(dim=None) triggers a runtime error on meta tensors
+            return self.meta_tensor.size()
 
     def __len__(self):
         self._assert_has_meta()
diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py
index e69de29bb..168a7bf95 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function.py
@@ -0,0 +1,62 @@
+import operator
+import torch
+from .registry import meta_patched_function
+
+
+@meta_patched_function.register(operator.getitem)
+def operator_getitem(a, b):
+    # copied from huggingface.utils.fx
+    def to_concrete(t):
+        if isinstance(t, torch.Tensor):
+            concrete = torch.ones_like(t, device="cpu")
+            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
+                concrete = concrete.to(torch.int64)
+            return concrete
+        return t
+
+    if isinstance(a, torch.Tensor):
+        # TODO: infer shape without performing the computation.
+        if isinstance(b, tuple):
+            b = tuple(map(to_concrete, b))
+        else:
+            b = to_concrete(b)
+        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
+    return operator.getitem(a, b)
+
+
+@meta_patched_function.register(torch.matmul)
+def torch_matmul(input, other, *, out=None):
+    # copied from huggingface.utils.fx
+    d1 = input.dim()
+    d2 = other.dim()
+    shape = None
+    if d1 == 1 and d2 == 1:
+        shape = None
+    elif d1 == 2 and d2 == 2:
+        shape = (input.size(0), other.size(1))
+    elif d1 == 1 and d2 == 2:
+        shape = (other.size(1),)
+    elif d1 == 2 and d2 == 1:
+        shape = (input.size(0),)
+    else:
+        max_length = max(input.dim(), other.dim())
+        shape1 = list(input.shape)
+        shape2 = list(other.shape)
+        if d1 == 1:
+            shape1 = [1] + shape1
+        if d2 == 1:
+            shape2.append(1)
+        shape1 = [-1] * (max_length - d1) + list(input.shape)
+        shape2 = [-1] * (max_length - d2) + list(other.shape)
+        shape = []
+        for i in range(max_length):
+            shape.append(max(shape1[i], shape2[i]))
+        shape[-2] = shape1[-2]
+        shape[-1] = shape2[-1]
+        if d1 == 1:
+            shape.pop(-2)
+        if d2 == 1:
+            shape.pop(-1)
+    if shape is None:
+        return torch.tensor(0.0, device="meta")
+    return torch.empty(*shape, device="meta")
diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module.py
index e3ece40df..bf6bc33da 100644
--- a/colossalai/fx/tracer/meta_patch/patched_module.py
+++ b/colossalai/fx/tracer/meta_patch/patched_module.py
@@ -30,7 +30,7 @@ def torch_nn_normalize(self, input):
 
 @meta_patched_module.register(torch.nn.Embedding)
 def torch_nn_embedding(self, input):
-    result_shape = input.shape[:-1] + (self.embedding_dim,)
+    result_shape = input.shape + (self.embedding_dim,)
     return torch.empty(result_shape, device='meta')
 
 
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index dfeaa8b5c..de39f745e 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -198,6 +198,16 @@ class ColoTracer(Tracer):
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
         meta_arg_names = set(meta_args.keys())
+
+        # update concrete args with default values
+        non_meta_arg_names = sig_names - meta_arg_names
+        for k, v in sig.parameters.items():
+            if k in non_meta_arg_names and \
+                k not in concrete_args and \
+                v.default is not inspect.Parameter.empty:
+                concrete_args[k] = v.default
+
+        # get non-concrete arg names
         concrete_arg_names = set(concrete_args.keys())
         non_concrete_arg_names = sig_names - concrete_arg_names
 
@@ -213,8 +223,12 @@
         # assign as attributes for later reference
         def _check_kwargs(kwargs, should_be_meta: bool):
             for k, v in kwargs.items():
-                assert v.is_meta == should_be_meta, \
-                    f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+                if not should_be_meta:
+                    assert not torch.is_tensor(v) or not v.is_meta, \
+                        f'Expected {k} not to be a meta tensor, please check the args passed to the tracer'
+                else:
+                    assert v.is_meta == should_be_meta, \
+                        f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
 
         _check_kwargs(concrete_args, should_be_meta=False)
         _check_kwargs(meta_args, should_be_meta=True)
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
new file mode 100644
index 000000000..303d7d5f3
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -0,0 +1,42 @@
+import transformers
+import torch
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def test_bert():
+    tracer = ColoTracer()
+    config = transformers.BertConfig()
+    model = transformers.BertModel(config=config)
+
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta')
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta')
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta')
+    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+    # make sure that the model is traceable
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # check output
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+
+    # must switch to eval mode so that dropout is disabled and the outputs are consistent
+    gm.eval()
+    model.eval()
+
+    # run forward
+    fx_out = gm(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    non_fx_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    assert fx_out['last_hidden_state'].shape == non_fx_out['last_hidden_state'].shape
+    assert torch.equal(fx_out['last_hidden_state'], non_fx_out['last_hidden_state'])
+
+
+if __name__ == '__main__':
+    test_bert()
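
The shape-inference patches in patched_function.py follow a simple registry
pattern: each patched implementation is keyed by the torch callable it
replaces, and the tracer dispatches to the patched version when that callable
is hit with meta-tensor arguments. Below is a minimal, self-contained sketch
of that pattern for illustration only; PatchRegistry and lookup are
hypothetical names, not ColossalAI's actual registry API.

    import torch


    class PatchRegistry:
        """Hypothetical stand-in for the meta_patched_function registry."""

        def __init__(self):
            self._patches = {}

        def register(self, target):
            # used as a decorator, mirroring @meta_patched_function.register(...)
            def wrapper(fn):
                self._patches[target] = fn
                return fn
            return wrapper

        def lookup(self, target):
            return self._patches.get(target)


    registry = PatchRegistry()


    @registry.register(torch.matmul)
    def patched_matmul(input, other, *, out=None):
        # shape-only matmul for the common 2D x 2D case
        return torch.empty(input.size(0), other.size(1), device='meta')


    x = torch.empty(4, 8, device='meta')
    w = torch.empty(8, 16, device='meta')
    out = registry.lookup(torch.matmul)(x, w)
    assert out.is_meta and out.shape == (4, 16)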
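
The torch_matmul patch mirrors PyTorch's matmul shape rules, including the
broadcasting of batch dimensions in the final else branch. A quick eager-mode
sanity check of those rules against real tensors (shapes chosen arbitrarily
for illustration):

    import torch

    # representative cases: vec x vec, mat x mat, vec x mat, mat x vec, batched
    cases = [((4,), (4,)),
             ((2, 3), (3, 5)),
             ((3,), (3, 5)),
             ((2, 3), (3,)),
             ((2, 1, 4, 5), (3, 5, 6))]
    for s1, s2 in cases:
        out = torch.matmul(torch.randn(*s1), torch.randn(*s2))
        print(s1, '@', s2, '->', tuple(out.shape))
    # (4,) @ (4,) -> ()                      scalar result
    # (2, 3) @ (3, 5) -> (2, 5)
    # (3,) @ (3, 5) -> (5,)
    # (2, 3) @ (3,) -> (2,)
    # (2, 1, 4, 5) @ (3, 5, 6) -> (2, 3, 4, 6)   batch dims broadcast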
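
The one-line fix in patched_module.py corrects the nn.Embedding meta patch: an
embedding lookup appends embedding_dim to the full index shape, while the old
input.shape[:-1] dropped the last input dimension (for BERT, the sequence
dimension). A quick eager-mode check with arbitrary sizes:

    import torch

    emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=32)
    ids = torch.zeros(2, 16, dtype=torch.int64)  # (batch, seq_len) indices
    out = emb(ids)
    assert out.shape == (2, 16, 32)  # == ids.shape + (embedding_dim,)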
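
The tracer change auto-fills concrete_args with the default values of any
forward() parameters that are neither supplied as meta args nor already given
concretely; this is what lets the test above trace BertModel while passing
only input_ids, token_type_ids and attention_mask. A standalone sketch of the
same inspect-based logic, using a made-up forward signature:

    import inspect


    def forward(input_ids, attention_mask=None, output_attentions=False):
        """Made-up signature standing in for a HuggingFace model's forward()."""


    meta_args = {'input_ids': None}    # args supplied as meta tensors
    concrete_args = {}                 # args the tracer treats as constants

    sig = inspect.signature(forward)
    non_meta_arg_names = set(sig.parameters) - set(meta_args)
    for k, v in sig.parameters.items():
        if k in non_meta_arg_names and \
                k not in concrete_args and \
                v.default is not inspect.Parameter.empty:
            concrete_args[k] = v.default

    print(concrete_args)  # {'attention_mask': None, 'output_attentions': False}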