[fx] added testing for all gpt variants (#1210)

* [fx] added testing for all gpt variants

* polish code

* polish code
Frank Lee 2022-07-06 14:03:13 +08:00 committed by GitHub
parent 189946c5c4
commit 2d13a45a3b
8 changed files with 136 additions and 72 deletions


@@ -1,8 +1,7 @@
 import operator
 import torch
 from torch.fx.proxy import Proxy, Attribute
-from typing import List, Union
-from torch.utils._pytree import PyTree
+from typing import List, Union, Any
 
 __all__ = ['ColoProxy']
@@ -14,34 +13,33 @@ class ColoProxy(Proxy):
     Usage:
         proxy = tracer.create_proxy(...)
-        proxy.meta_tensor = torch.empty(4, 2, device='meta')
+        proxy.meta_data = torch.empty(4, 2, device='meta')
         print(len(proxy)) # expect output 4
     """
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._meta_tensor = None
+        self._meta_data = None
 
     @property
-    def meta_tensor(self):
-        return self._meta_tensor
+    def meta_data(self):
+        return self._meta_data
 
-    @meta_tensor.setter
-    def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]):
-
-        def _is_meta(item):
-            assert torch.is_tensor(item) and item.is_meta
-
-        torch.fx.node.map_aggregate(tensor, _is_meta)
-        self._meta_tensor = tensor
+    @meta_data.setter
+    def meta_data(self, data: Any):
+        self._meta_data = data
 
     @property
-    def has_meta_tensor(self):
-        return self.meta_tensor is not None
+    def has_meta_data(self):
+        return self._meta_data is not None
 
-    def _assert_has_meta(self):
-        assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
+    def _assert_meta_data_is_tensor(self):
+        assert torch.is_tensor(
+            self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
+
+    def _assert_has_meta_data(self):
+        # check for presence explicitly; truthiness of a multi-element tensor raises
+        assert self._meta_data is not None, f'Meta data is not set for {self.node.name}'
 
     @property
     def device(self):
@@ -51,37 +49,37 @@ class ColoProxy(Proxy):
     @property
     def dtype(self):
-        self._assert_has_meta()
-        return self.meta_tensor.dtype
+        self._assert_meta_data_is_tensor()
+        return self.meta_data.dtype
 
     @property
     def shape(self):
-        self._assert_has_meta()
-        return self.meta_tensor.shape
+        self._assert_meta_data_is_tensor()
+        return self.meta_data.shape
 
     def dim(self):
-        self._assert_has_meta()
-        return self.meta_tensor.dim()
+        self._assert_meta_data_is_tensor()
+        return self.meta_data.dim()
 
     def size(self, dim: int = None):
-        self._assert_has_meta()
+        self._assert_meta_data_is_tensor()
         if dim:
-            return self.meta_tensor.size(dim=dim)
+            return self.meta_data.size(dim=dim)
         else:
             # size(dim=None) will trigger runtime error for meta tensor
-            return self.meta_tensor.size()
+            return self.meta_data.size()
 
     def __len__(self):
-        self._assert_has_meta()
-        return len(self.meta_tensor)
+        self._assert_has_meta_data()
+        return len(self.meta_data)
 
     def __bool__(self):
-        self._assert_has_meta()
-        return self.meta_tensor
+        self._assert_has_meta_data()
+        return self.meta_data
 
     def __getattr__(self, k):
-        if k == "metadata":
-            return self.meta_tensor
+        if k == "meta_data":
+            return self.__getattribute__(k)
         # note: not added to the graph yet, if this is a method call
         # we peephole optimize to the method invocation
         return Attribute(self, k)
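
As a usage sketch of the renamed accessors (mirroring the updated docstring and the test_coloproxy hunk below; the bare placeholder node and the colossalai.fx import path are our assumptions, not part of the commit):

import torch
from torch.fx import Graph
from colossalai.fx import ColoProxy

# scaffolding: a Proxy must wrap an fx node, so create a bare placeholder
graph = Graph()
node = graph.create_node('placeholder', 'x')
proxy = ColoProxy(node=node)

# meta_data now accepts any object; tensor-only checks run lazily in
# _assert_meta_data_is_tensor rather than eagerly in the setter
proxy.meta_data = torch.empty(4, 2, device='meta')
print(proxy.shape)   # torch.Size([4, 2])
print(len(proxy))    # 4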


@@ -22,8 +22,8 @@ def extract_meta(*args, **kwargs):
         if isinstance(val, MetaDeviceAttribute):
             return 'meta'
         elif isinstance(val, ColoProxy):
-            assert val.meta_tensor is not None
-            return val.meta_tensor
+            assert val.meta_data is not None
+            return val.meta_data
         return val
 
     new_args = [_convert(val) for val in args]


@@ -60,3 +60,32 @@ def torch_matmul(input, other, *, out=None):
     if shape is None:
         return torch.tensor(0.0, device="meta")
     return torch.empty(*shape, device="meta")
+
+
+@meta_patched_function.register(torch.arange)
+def torch_arange(*args, **kwargs):
+    n = len(args)
+    step = 1
+    if n == 1:
+        start = 0
+        end = args[0]
+    elif n == 2:
+        start, end = args
+    else:
+        start, end, step = args
+    if isinstance(start, float):
+        start = int(start)
+    if isinstance(end, float):
+        end = int(end)
+    if isinstance(step, float):
+        step = int(step)
+    step = kwargs.get("step", step)
+    dtype = kwargs.get("dtype")
+    # torch.arange produces ceil((end - start) / step) elements
+    return torch.empty((end - start + step - 1) // step, dtype=dtype, device="meta")
+
+
+@meta_patched_function.register(torch.where)
+def torch_where(condition, x, y):
+    # torch.where returns the broadcasted tensor of condition, x, and y,
+    # so hack it by using addition
+    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
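
These patches only have to reproduce shapes, which is easy to sanity-check against eager execution; a minimal verification sketch of our own (not part of the commit):

import torch

# arange: eager element count is ceil((end - start) / step), which the
# patched function above mirrors with integer arithmetic
for start, end, step in [(0, 16, 1), (2, 10, 2), (0, 7, 3)]:
    eager = torch.arange(start, end, step)
    assert eager.numel() == (end - start + step - 1) // step

# where: broadcasting the three operands via addition yields the same shape
cond = torch.zeros(4, 1, dtype=torch.bool)
x, y = torch.zeros(1, 5), torch.zeros(4, 5)
assert torch.where(cond, x, y).shape == (cond + x + y).shape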


@@ -14,7 +14,6 @@ from torch import Tensor
 from torch.fx import Tracer
 from torch.fx.graph import Graph
 from torch.fx.proxy import Proxy, ParameterProxy
-from torch.utils import _pytree
 from ..proxy import ColoProxy
 from typing import Optional, Dict, Any
 from ._tracer_utils import is_element_in_list, extract_meta
@@ -62,7 +61,7 @@ class ColoTracer(Tracer):
         proxy: ColoProxy
 
         if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
-            proxy.meta_tensor = self.meta_args[target]
+            proxy.meta_data = self.meta_args[target]
             return proxy
 
         if target in self.orig_torch_tensor_methods:
@@ -128,7 +127,7 @@ class ColoTracer(Tracer):
             if not isinstance(proxy, Proxy):
                 raise ValueError("Don't support composite output yet")
-            proxy.meta_tensor = meta_out
+            proxy.meta_data = meta_out
         except Exception as e:
             raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
 
         return proxy
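
Taken together, the tracer change means meta data now flows through every proxy via meta_data; a minimal end-to-end sketch with a toy module (the argument name 'input' matches nn.Linear.forward, and the rest follows the pattern of the tests below):

import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

# trace with meta tensors so no real kernels run during tracing
model = nn.Linear(4, 2)
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args={'input': torch.empty(2, 4, device='meta')})

gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
assert gm(torch.rand(2, 4)).shape == (2, 2)   # the traced copy runs on real data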


@@ -10,7 +10,7 @@ def test_coloproxy():
     # create proxy
     proxy = ColoProxy(node=node)
-    proxy.meta_tensor = torch.empty(4, 2, device='meta')
+    proxy.meta_data = torch.empty(4, 2, device='meta')
 
     assert len(proxy) == 4
     assert proxy.shape[0] == 4 and proxy.shape[1] == 2


@@ -1,39 +1,11 @@
 import transformers
 import torch
-from colossalai.fx import ColoTracer
-from torch.fx import GraphModule
+from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
-
-def trace_bert_and_compare_output(model, data_gen):
-    tracer = ColoTracer()
-
-    # make sure that the model is traceable
-    try:
-        kwargs = data_gen()
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-        graph = tracer.trace(root=model, meta_args=meta_args)
-    except Exception as e:
-        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-
-    # check output
-    inputs = data_gen()
-
-    # must turn on eval mode to ensure the output is consistent
-    gm.eval()
-    model.eval()
-
-    # run forward
-    non_fx_out = model(**inputs)
-    fx_out = gm(**inputs)
-
-    for k in non_fx_out.keys():
-        assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
-
 
 def test_single_sentence_bert():
     MODEL_LIST = [
         transformers.BertModel,
@@ -55,7 +27,7 @@ def test_single_sentence_bert():
 
     for model_cls in MODEL_LIST:
         model = model_cls(config=config)
-        trace_bert_and_compare_output(model, data_gen)
+        trace_model_and_compare_output(model, data_gen)
 
 
 def test_multi_sentence_bert():
@@ -69,7 +41,7 @@ def test_multi_sentence_bert():
         return encoding
 
     model = transformers.BertForNextSentencePrediction(config)
-    trace_bert_and_compare_output(model, data_gen_for_next_sentence)
+    trace_model_and_compare_output(model, data_gen_for_next_sentence)
 
     def data_gen_for_qa():
         question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
@@ -77,7 +49,7 @@ def test_multi_sentence_bert():
         return inputs
 
     model = transformers.BertForQuestionAnswering(config)
-    trace_bert_and_compare_output(model, data_gen_for_qa)
+    trace_model_and_compare_output(model, data_gen_for_qa)
 
     def data_gen_for_mcq():
         prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
@@ -88,7 +60,7 @@ def test_multi_sentence_bert():
         return encoding
 
     model = transformers.BertForMultipleChoice(config)
-    trace_bert_and_compare_output(model, data_gen_for_mcq)
+    trace_model_and_compare_output(model, data_gen_for_mcq)
 
 if __name__ == '__main__':


@@ -0,0 +1,33 @@
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGHT = 16
+
+
+def test_gpt():
+    MODEL_LIST = [
+        transformers.GPT2Model,
+        transformers.GPT2LMHeadModel,
+        transformers.GPT2DoubleHeadsModel,
+        transformers.GPT2ForTokenClassification,
+        # transformers.GPT2ForSequenceClassification, # not supported yet
+    ]
+
+    # note: the GPT2Config keyword is n_positions (plural)
+    config = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_gpt()
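
Beyond the output-equality check, the traced GraphModule can also be inspected; a short sketch of our own reusing the config above (print_tabular requires the tabulate package):

import torch
import transformers
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

config = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4)
model = transformers.GPT2Model(config)

tracer = ColoTracer()
meta_args = {'input_ids': torch.zeros((1, 16), dtype=torch.int64, device='meta')}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.graph.print_tabular()   # one row per traced node: opcode, name, target, args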


@@ -0,0 +1,33 @@
+import torch
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+from torch.utils._pytree import tree_flatten
+
+
+def trace_model_and_compare_output(model, data_gen):
+    tracer = ColoTracer()
+
+    # make sure that the model is traceable
+    try:
+        kwargs = data_gen()
+        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        graph = tracer.trace(root=model, meta_args=meta_args)
+    except Exception as e:
+        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # check output
+    inputs = data_gen()
+
+    # must turn on eval mode to ensure the output is consistent
+    gm.eval()
+    model.eval()
+
+    # run forward
+    non_fx_out = model(**inputs)
+    fx_out = gm(**inputs)
+
+    for k in non_fx_out.keys():
+        # skip non-tensor entries (e.g. past_key_values tuples); only plain
+        # tensor outputs are compared element-wise
+        if torch.is_tensor(fx_out[k]):
+            assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
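
The helper only assumes a model whose forward returns a mapping of tensors, so it is reusable beyond transformers; a toy sketch (DictModel is our own illustration, not part of the commit):

import torch
import torch.nn as nn
from utils import trace_model_and_compare_output


class DictModel(nn.Module):
    # toy module whose forward returns a dict, like a transformers output

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, input_ids):
        return {'logits': self.linear(input_ids)}


trace_model_and_compare_output(DictModel(), lambda: {'input_ids': torch.rand(2, 8)})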