mirror of https://github.com/hpcaitech/ColossalAI

[fx] supported model tracing for huggingface bert (#1201)

* [fx] supported model tracing for huggingface bert
* polish test

parent 060b917daf
commit f7878f465c
@@ -59,7 +59,11 @@ class ColoProxy(Proxy):
 
     def size(self, dim: int = None):
         self._assert_has_meta()
-        return self.meta_tensor.size(dim=dim)
+        if dim:
+            return self.meta_tensor.size(dim=dim)
+        else:
+            # size(dim=None) will trigger runtime error for meta tensor
+            return self.meta_tensor.size()
 
     def __len__(self):
         self._assert_has_meta()
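For illustration (not part of the commit): per the in-code comment above, forwarding dim=None to Tensor.size() can raise at runtime for meta tensors, so the proxy only passes dim through when one is actually given. A minimal sketch of the two safe call forms, assuming a PyTorch build with meta-tensor support:

import torch

meta = torch.empty(2, 3, device='meta')

print(meta.size())       # torch.Size([2, 3]) -- no dim argument, always safe
print(meta.size(dim=1))  # 3 -- explicit integer dim, also safe
# meta.size(dim=None)    # per the commit's comment this form can error out,
                         # which is what the if/else in ColoProxy.size avoids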
@@ -0,0 +1,62 @@
+import operator
+import torch
+from .registry import meta_patched_function
+
+
+@meta_patched_function.register(operator.getitem)
+def operator_getitem(a, b):
+    # copied from huggingface.utils.fx
+    def to_concrete(t):
+        if isinstance(t, torch.Tensor):
+            concrete = torch.ones_like(t, device="cpu")
+            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
+                concrete = concrete.to(torch.int64)
+            return concrete
+        return t
+
+    if isinstance(a, torch.Tensor):
+        # TODO: infer shape without performing the computation.
+        if isinstance(b, tuple):
+            b = tuple(map(to_concrete, b))
+        else:
+            b = to_concrete(b)
+        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
+    return operator.getitem(a, b)
+
+
+@meta_patched_function.register(torch.matmul)
+def torch_matmul(input, other, *, out=None):
+    # copied from huggingface.utils.fx
+    d1 = input.dim()
+    d2 = other.dim()
+    shape = None
+    if d1 == 1 and d2 == 1:
+        shape = None
+    elif d1 == 2 and d2 == 2:
+        shape = (input.size(0), other.size(1))
+    elif d1 == 1 and d2 == 2:
+        shape = (other.size(1),)
+    elif d1 == 2 and d2 == 1:
+        shape = (input.size(0),)
+    else:
+        max_length = max(input.dim(), other.dim())
+        shape1 = list(input.shape)
+        shape2 = list(other.shape)
+        if d1 == 1:
+            shape1 = [1] + shape1
+        if d2 == 1:
+            shape2.append(1)
+        shape1 = [-1] * (max_length - d1) + list(input.shape)
+        shape2 = [-1] * (max_length - d2) + list(other.shape)
+        shape = []
+        for i in range(max_length):
+            shape.append(max(shape1[i], shape2[i]))
+        shape[-2] = shape1[-2]
+        shape[-1] = shape2[-1]
+        if d1 == 1:
+            shape.pop(-2)
+        if d2 == 1:
+            shape.pop(-1)
+    if shape is None:
+        return torch.tensor(0.0, device="meta")
+    return torch.empty(*shape, device="meta")
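As a side note (not part of the commit), the branches in torch_matmul above mirror torch.matmul's documented shape rules for the 1-D, 2-D and broadcast cases; a quick CPU check:

import torch

# Compare the shape cases handled above against what torch.matmul actually produces.
cases = [
    (torch.randn(4), torch.randn(4)),           # 1-D @ 1-D -> scalar, shape ()
    (torch.randn(2, 3), torch.randn(3, 5)),     # 2-D @ 2-D -> (2, 5)
    (torch.randn(3), torch.randn(3, 5)),        # 1-D @ 2-D -> (5,)
    (torch.randn(2, 3), torch.randn(3)),        # 2-D @ 1-D -> (2,)
    (torch.randn(8, 2, 3), torch.randn(3, 5)),  # batched/broadcast -> (8, 2, 5)
]
for a, b in cases:
    print(tuple(torch.matmul(a, b).shape))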
@@ -30,7 +30,7 @@ def torch_nn_normalize(self, input):
 
 @meta_patched_module.register(torch.nn.Embedding)
 def torch_nn_embedding(self, input):
-    result_shape = input.shape[:-1] + (self.embedding_dim,)
+    result_shape = input.shape + (self.embedding_dim,)
     return torch.empty(result_shape, device='meta')
 
 
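The one-line fix above reflects how nn.Embedding actually behaves: it maps integer indices of shape (batch, seq) to embeddings of shape (batch, seq, embedding_dim), so the meta output shape is input.shape + (embedding_dim,), not input.shape[:-1] + (embedding_dim,), which would drop the sequence dimension. A quick check (not part of the commit):

import torch

embed = torch.nn.Embedding(num_embeddings=30522, embedding_dim=768)
indices = torch.zeros(2, 16, dtype=torch.long)
print(embed(indices).shape)  # torch.Size([2, 16, 768]) == indices.shape + (embedding_dim,)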
@@ -198,6 +198,16 @@ class ColoTracer(Tracer):
         sig = inspect.signature(root.forward)
         sig_names = set(sig.parameters.keys())
         meta_arg_names = set(meta_args.keys())
+
+        # update concrete args with default values
+        non_meta_arg_names = sig_names - meta_arg_names
+        for k, v in sig.parameters.items():
+            if k in non_meta_arg_names and \
+                    k not in concrete_args and \
+                    v.default is not inspect.Parameter.empty:
+                concrete_args[k] = v.default
+
+        # get non concrete arg names
         concrete_arg_names = set(concrete_args.keys())
         non_concrete_arg_names = sig_names - concrete_arg_names
 
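For illustration (not part of the commit): the added loop pre-fills concrete_args with the default value of any forward() parameter that is neither a meta arg nor already concrete. A standalone sketch of the same mechanism, with a made-up forward signature and argument names:

import inspect

def forward(input_ids, attention_mask=None, output_attentions=False, return_dict=True):
    ...

sig = inspect.signature(forward)
meta_args = {'input_ids': None}            # args that will be traced with meta tensors
concrete_args = {'attention_mask': None}   # args the user already fixed

# any remaining parameter with a default gets folded into concrete_args
for k, v in sig.parameters.items():
    if k not in meta_args and k not in concrete_args and v.default is not inspect.Parameter.empty:
        concrete_args[k] = v.default

print(concrete_args)
# {'attention_mask': None, 'output_attentions': False, 'return_dict': True}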
@@ -213,8 +223,12 @@ class ColoTracer(Tracer):
         # assign as attributed for late reference
         def _check_kwargs(kwargs, should_be_meta: bool):
             for k, v in kwargs.items():
-                assert v.is_meta == should_be_meta, \
-                    f'expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+                if not should_be_meta:
+                    assert not torch.is_tensor(v) or not v.is_meta, \
+                        f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
+                else:
+                    assert v.is_meta == should_be_meta, \
+                        f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
 
         _check_kwargs(concrete_args, should_be_meta=False)
         _check_kwargs(meta_args, should_be_meta=True)
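Illustration only (not part of the commit): the split the checks above enforce is that meta_args hold meta tensors while concrete_args hold ordinary Python values or real tensors. The argument names below are hypothetical:

import torch

meta_args = {'input_ids': torch.zeros(2, 16, dtype=torch.int64, device='meta')}
concrete_args = {'use_cache': False, 'position_ids': torch.arange(16).unsqueeze(0)}

# meta args must be meta tensors; concrete args must not be meta tensors
assert all(v.is_meta for v in meta_args.values())
assert all(not (torch.is_tensor(v) and v.is_meta) for v in concrete_args.values())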
@@ -0,0 +1,42 @@
+import transformers
+import torch
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+
+BATCH_SIZE = 2
+SEQ_LENGHT = 16
+
+
+def test_bert():
+    tracer = ColoTracer()
+    config = transformers.BertConfig()
+    model = transformers.BertModel(config=config)
+
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64, device='meta')
+    meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+    # make sure that the model is traceable
+    graph = tracer.trace(root=model, meta_args=meta_args)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # check output
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+
+    # must turn on eval mode to ensure the output is consistent
+    gm.eval()
+    model.eval()
+
+    # run forward
+    fx_out = gm(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    non_fx_out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+    assert fx_out['last_hidden_state'].shape == non_fx_out['last_hidden_state'].shape
+    assert torch.equal(fx_out['last_hidden_state'], non_fx_out['last_hidden_state'])
+
+
+if __name__ == '__main__':
+    test_bert()
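As a possible follow-up to the test above (not part of the commit), the traced result is an ordinary torch.fx GraphModule, so its graph and generated code can be inspected directly; the smaller BertConfig here is only to keep the sketch quick, and the tracer call mirrors the one used in the test:

import torch
import transformers
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

config = transformers.BertConfig(num_hidden_layers=1)
model = transformers.BertModel(config=config)

meta_args = {
    k: torch.zeros((2, 16), dtype=torch.int64, device='meta')
    for k in ('input_ids', 'token_type_ids', 'attention_mask')
}
graph = ColoTracer().trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)

print(len(graph.nodes))  # number of traced nodes
print(gm.code[:300])     # beginning of the generated forward() source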