[fx] supported model tracing for huggingface bert (#1201)

* [fx] supported model tracing for huggingface bert

* polish test
Frank Lee 2022-07-05 13:19:57 +08:00 committed by GitHub
parent 060b917daf
commit f7878f465c
5 changed files with 126 additions and 4 deletions


@@ -59,7 +59,11 @@ class ColoProxy(Proxy):
     def size(self, dim: int = None):
         self._assert_has_meta()
-        return self.meta_tensor.size(dim=dim)
+        if dim is not None:
+            return self.meta_tensor.size(dim=dim)
+        else:
+            # size(dim=None) will trigger a runtime error for a meta tensor
+            return self.meta_tensor.size()
 
     def __len__(self):
         self._assert_has_meta()

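For context on the branch added here: per the comment in the patch, explicitly passing dim=None to size() on a meta tensor raises a runtime error on the targeted PyTorch version, so the proxy only forwards dim when it was actually supplied. A minimal sketch of the two safe call patterns (not part of the commit):

import torch

t = torch.empty(2, 3, device='meta')
print(t.size())       # torch.Size([2, 3]) -- no dim argument
print(t.size(dim=1))  # 3 -- explicit integer dim
# forwarding the default dim=None straight through, i.e. t.size(dim=None),
# is the case the if/else above avoids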

@@ -0,0 +1,62 @@
import operator
import torch
from .registry import meta_patched_function


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    # copied from huggingface.utils.fx
    def to_concrete(t):
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    return operator.getitem(a, b)


@meta_patched_function.register(torch.matmul)
def torch_matmul(input, other, *, out=None):
    # copied from huggingface.utils.fx
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")

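As an illustration of what the two patched functions compute (not part of the commit), the shapes they report should agree with the shapes real tensor ops produce. The sketch below assumes that meta_patched_function.register returns the wrapped function unchanged, so operator_getitem and torch_matmul can be called directly:

import torch

# hypothetical sanity check under the assumption above
a = torch.empty(4, 8, device='meta')
assert operator_getitem(a, (slice(None), 0)).shape == torch.Size([4])  # same shape as a[:, 0]

for s1, s2 in [((4, 5), (5, 6)), ((3,), (3, 7)), ((2, 3, 4), (4, 5))]:
    real = torch.matmul(torch.randn(s1), torch.randn(s2))
    meta = torch_matmul(torch.empty(*s1, device='meta'), torch.empty(*s2, device='meta'))
    assert real.shape == meta.shape, (real.shape, meta.shape)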

@@ -30,7 +30,7 @@ def torch_nn_normalize(self, input):
 @meta_patched_module.register(torch.nn.Embedding)
 def torch_nn_embedding(self, input):
-    result_shape = input.shape[:-1] + (self.embedding_dim,)
+    result_shape = input.shape + (self.embedding_dim,)
     return torch.empty(result_shape, device='meta')

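The reasoning behind the one-line change above: nn.Embedding maps every index in its input to a vector, so the output shape is the full input shape with embedding_dim appended, not input.shape[:-1]. A quick check against a real (non-meta) module:

import torch

emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)
ids = torch.zeros(2, 16, dtype=torch.int64)  # (batch, seq_len) token indices
out = emb(ids)
assert out.shape == ids.shape + (8,)         # torch.Size([2, 16, 8])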

@@ -198,6 +198,16 @@ class ColoTracer(Tracer):
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
         meta_arg_names = set(meta_args.keys())
+
+        # update concrete args with default values
+        non_meta_arg_names = sig_names - meta_arg_names
+        for k, v in sig.parameters.items():
+            if k in non_meta_arg_names and \
+                    k not in concrete_args and \
+                    v.default is not inspect.Parameter.empty:
+                concrete_args[k] = v.default
+
+        # get non concrete arg names
         concrete_arg_names = set(concrete_args.keys())
         non_concrete_arg_names = sig_names - concrete_arg_names
@@ -213,8 +223,12 @@
         # assign as attribute for later reference
         def _check_kwargs(kwargs, should_be_meta: bool):
             for k, v in kwargs.items():
-                assert v.is_meta == should_be_meta, \
-                    f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+                if not should_be_meta:
+                    assert not torch.is_tensor(v) or not v.is_meta, \
+                        f'Expected {k} not to be a meta tensor, please check the args passed to the tracer'
+                else:
+                    assert v.is_meta == should_be_meta, \
+                        f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
 
         _check_kwargs(concrete_args, should_be_meta=False)
         _check_kwargs(meta_args, should_be_meta=True)

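With the default-filling loop above, callers only need to describe tensor inputs via meta_args; defaulted non-tensor arguments are picked up as concrete args automatically. A minimal sketch with a made-up module (Toy and its scale argument are hypothetical, for illustration only):

import torch
from colossalai.fx import ColoTracer

class Toy(torch.nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

tracer = ColoTracer()
# 'scale' is taken from its default value and treated as a concrete arg,
# so only the tensor input has to be passed as a meta tensor
graph = tracer.trace(root=Toy(), meta_args={'x': torch.empty(4, 4, device='meta')})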

@@ -0,0 +1,42 @@
import transformers
import torch
from colossalai.fx import ColoTracer
from torch.fx import GraphModule

BATCH_SIZE = 2
SEQ_LENGTH = 16


def test_bert():
    tracer = ColoTracer()
    config = transformers.BertConfig()
    model = transformers.BertModel(config=config)

    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta')
    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta')
    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64, device='meta')
    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

    # make sure that the model is traceable
    graph = tracer.trace(root=model, meta_args=meta_args)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # check output
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)

    # must turn on eval mode to ensure the output is consistent
    gm.eval()
    model.eval()

    # run forward
    fx_out = gm(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    non_fx_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    assert fx_out['last_hidden_state'].shape == non_fx_out['last_hidden_state'].shape
    assert torch.equal(fx_out['last_hidden_state'], non_fx_out['last_hidden_state'])


if __name__ == '__main__':
    test_bert()
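To inspect what the tracer produced, the standard torch.fx utilities can be applied to gm inside test_bert, for example:

print(gm.code)            # the generated Python forward
gm.graph.print_tabular()  # node-by-node view (requires the tabulate package)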