diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index 2f8ca6d94..b36078f0a 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -1,8 +1,7 @@
 import operator
 import torch
 from torch.fx.proxy import Proxy, Attribute
-from typing import List, Union
-from torch.utils._pytree import PyTree
+from typing import List, Union, Any
 
 __all__ = ['ColoProxy']
 
@@ -14,34 +13,33 @@ class ColoProxy(Proxy):
 
     Usage:
         proxy = tracer.create_proxy(...)
-        proxy.meta_tensor = torch.empty(4, 2, device='meta')
+        proxy.meta_data = torch.empty(4, 2, device='meta')
         print(len(proxy)) # expect output 4
     """
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._meta_tensor = None
+        self._meta_data = None
 
     @property
-    def meta_tensor(self):
-        return self._meta_tensor
+    def meta_data(self):
+        return self._meta_data
 
-    @meta_tensor.setter
-    def meta_tensor(self, tensor: Union[List[torch.Tensor], torch.Tensor]):
-
-        def _is_meta(item):
-            assert torch.is_tensor(item) and item.is_meta
-
-        torch.fx.node.map_aggregate(tensor, _is_meta)
-        self._meta_tensor = tensor
+    @meta_data.setter
+    def meta_data(self, data: Any):
+        self._meta_data = data
 
     @property
-    def has_meta_tensor(self):
-        return self.meta_tensor is not None
+    def has_meta_data(self):
+        return self._meta_data is not None
 
-    def _assert_has_meta(self):
-        assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'
+    def _assert_meta_data_is_tensor(self):
+        assert torch.is_tensor(
+            self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}'
+
+    def _assert_has_meta_data(self):
+        assert self._meta_data, f'Meta data is not set for {self.node.name}'
 
     @property
     def device(self):
@@ -51,37 +49,37 @@ class ColoProxy(Proxy):
 
     @property
     def dtype(self):
-        self._assert_has_meta()
-        return self.meta_tensor.dtype
+        self._assert_meta_data_is_tensor()
+        return self.meta_data.dtype
 
     @property
     def shape(self):
-        self._assert_has_meta()
-        return self.meta_tensor.shape
+        self._assert_meta_data_is_tensor()
+        return self.meta_data.shape
 
     def dim(self):
-        self._assert_has_meta()
-        return self.meta_tensor.dim()
+        self._assert_meta_data_is_tensor()
+        return self.meta_data.dim()
 
     def size(self, dim: int = None):
-        self._assert_has_meta()
+        self._assert_meta_data_is_tensor()
         if dim:
-            return self.meta_tensor.size(dim=dim)
+            return self.meta_data.size(dim=dim)
         else:
             # size(dim=None) will trigger runtime error for meta tensor
-            return self.meta_tensor.size()
+            return self.meta_data.size()
 
     def __len__(self):
-        self._assert_has_meta()
-        return len(self.meta_tensor)
+        self._assert_has_meta_data()
+        return len(self.meta_data)
 
     def __bool__(self):
-        self._assert_has_meta()
-        return self.meta_tensor
+        self._assert_has_meta_data()
+        return self.meta_data
 
     def __getattr__(self, k):
-        if k == "metadata":
-            return self.meta_tensor
+        if k == "meta_data":
+            return self.__getattribute__(k)
         # note: not added to the graph yet, if this is a method call
         # we peephole optimize to the method invocation
         return Attribute(self, k)
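A minimal usage sketch (not part of the patch; the bare-placeholder setup below is an assumption, loosely mirroring tests/test_fx/test_coloproxy.py): after the rename, meta_data may hold arbitrary Python objects, while the tensor-only properties (shape, dtype, dim, size) still assert that a meta tensor is stored.

import torch
from torch.fx import Graph
from colossalai.fx.proxy import ColoProxy

graph = Graph()
node = graph.placeholder('x')                        # bare node, just to host the proxy
proxy = ColoProxy(node=node)

proxy.meta_data = torch.empty(4, 2, device='meta')   # a meta tensor: tensor-only properties work
assert proxy.shape == torch.Size([4, 2]) and proxy.dtype == torch.float32

proxy.meta_data = ['a', 'b', 'c']                    # arbitrary (non-tensor) meta data is now accepted
assert len(proxy) == 3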
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index 05ecd3843..528c4a8e9 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -22,8 +22,8 @@ def extract_meta(*args, **kwargs):
         if isinstance(val, MetaDeviceAttribute):
             return 'meta'
         elif isinstance(val, ColoProxy):
-            assert val.meta_tensor is not None
-            return val.meta_tensor
+            assert val.meta_data is not None
+            return val.meta_data
         return val
 
     new_args = [_convert(val) for val in args]
diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py
index 168a7bf95..a5f57f6ac 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function.py
@@ -60,3 +60,32 @@ def torch_matmul(input, other, *, out=None):
     if shape is None:
         return torch.tensor(0.0, device="meta")
     return torch.empty(*shape, device="meta")
+
+
+@meta_patched_function.register(torch.arange)
+def torch_arange(*args, **kwargs):
+    n = len(args)
+    step = 1
+    if n == 1:
+        start = 0
+        end = args[0]
+    elif n == 2:
+        start, end = args
+    else:
+        start, end, step = args
+    if isinstance(start, float):
+        start = int(start)
+    if isinstance(end, float):
+        end = int(end)
+    if isinstance(step, float):
+        step = int(step)
+    step = kwargs.get("step", step)
+    dtype = kwargs.get("dtype")
+    return torch.empty((end - start) // step, dtype=dtype, device="meta")
+
+
+@meta_patched_function.register(torch.where)
+def torch_where(condition, x, y):
+    # torch.where returns the broadcasted tensor of condition, x, and y,
+    # so hack it by using addition
+    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
\ No newline at end of file
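A standalone sketch (not part of the patch) of why the torch.where patch works: adding meta tensors elementwise yields the broadcast result shape without touching any data, which is all the tracer needs for shape inference.

import torch

condition = torch.empty(4, 1, dtype=torch.bool, device='meta')
x = torch.empty(1, 3, device='meta')
y = torch.empty(4, 3, device='meta')

# the sum has the broadcast shape that torch.where(condition, x, y) would produce
out = condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
assert out.shape == torch.Size([4, 3])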
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index de39f745e..e4191f88c 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -14,7 +14,6 @@ from torch import Tensor
 from torch.fx import Tracer
 from torch.fx.graph import Graph
 from torch.fx.proxy import Proxy, ParameterProxy
-from torch.utils import _pytree
 from ..proxy import ColoProxy
 from typing import Optional, Dict, Any
 from ._tracer_utils import is_element_in_list, extract_meta
@@ -62,7 +61,7 @@ class ColoTracer(Tracer):
         proxy: ColoProxy
 
         if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
-            proxy.meta_tensor = self.meta_args[target]
+            proxy.meta_data = self.meta_args[target]
             return proxy
 
         if target in self.orig_torch_tensor_methods:
@@ -128,7 +127,7 @@
 
             if not isinstance(proxy, Proxy):
                 raise ValueError("Don't support composite output yet")
-            proxy.meta_tensor = meta_out
+            proxy.meta_data = meta_out
         except Exception as e:
             raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
         return proxy
diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py
index a1c75d168..ea3a064e3 100644
--- a/tests/test_fx/test_coloproxy.py
+++ b/tests/test_fx/test_coloproxy.py
@@ -10,7 +10,7 @@ def test_coloproxy():
 
     # create proxy
     proxy = ColoProxy(node=node)
-    proxy.meta_tensor = torch.empty(4, 2, device='meta')
+    proxy.meta_data = torch.empty(4, 2, device='meta')
 
     assert len(proxy) == 4
     assert proxy.shape[0] == 4 and proxy.shape[1] == 2
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
index b9a855cef..9bf600625 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -1,39 +1,11 @@
 import transformers
 import torch
-from colossalai.fx import ColoTracer
-from torch.fx import GraphModule
+from utils import trace_model_and_compare_output
 
 BATCH_SIZE = 2
 SEQ_LENGHT = 16
 
 
-def trace_bert_and_compare_output(model, data_gen):
-    tracer = ColoTracer()
-    # make sure that the model is traceable
-    try:
-        kwargs = data_gen()
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-        graph = tracer.trace(root=model, meta_args=meta_args)
-    except Exception as e:
-        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-
-    # check output
-    inputs = data_gen()
-
-    # must turn on eval mode to ensure the output is consistent
-    gm.eval()
-    model.eval()
-
-    # run forward
-    non_fx_out = model(**inputs)
-    fx_out = gm(**inputs)
-
-    for k in non_fx_out.keys():
-        assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
-
-
 def test_single_sentence_bert():
     MODEL_LIST = [
         transformers.BertModel,
@@ -55,7 +27,7 @@
 
     for model_cls in MODEL_LIST:
         model = model_cls(config=config)
-        trace_bert_and_compare_output(model, data_gen)
+        trace_model_and_compare_output(model, data_gen)
 
 
 def test_multi_sentence_bert():
@@ -69,7 +41,7 @@
         return encoding
 
     model = transformers.BertForNextSentencePrediction(config)
-    trace_bert_and_compare_output(model, data_gen_for_next_sentence)
+    trace_model_and_compare_output(model, data_gen_for_next_sentence)
 
     def data_gen_for_qa():
         question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
@@ -77,7 +49,7 @@
         return inputs
 
     model = transformers.BertForQuestionAnswering(config)
-    trace_bert_and_compare_output(model, data_gen_for_qa)
+    trace_model_and_compare_output(model, data_gen_for_qa)
 
     def data_gen_for_mcq():
         prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
@@ -88,7 +60,7 @@
         return encoding
 
     model = transformers.BertForMultipleChoice(config)
-    trace_bert_and_compare_output(model, data_gen_for_mcq)
+    trace_model_and_compare_output(model, data_gen_for_mcq)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
new file mode 100644
index 000000000..abd8b8ae0
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -0,0 +1,33 @@
+import transformers
+import torch
+from utils import trace_model_and_compare_output
+
+BATCH_SIZE = 1
+SEQ_LENGHT = 16
+
+
+def test_gpt():
+    MODEL_LIST = [
+        transformers.GPT2Model,
+        transformers.GPT2LMHeadModel,
+        transformers.GPT2DoubleHeadsModel,
+        transformers.GPT2ForTokenClassification,
+        # transformers.GPT2ForSequenceClassification, # not supported yet
+    ]
+
+    config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
+
+    def data_gen():
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        return kwargs
+
+    for model_cls in MODEL_LIST:
+        model = model_cls(config=config)
+        trace_model_and_compare_output(model, data_gen)
+
+
+if __name__ == '__main__':
+    test_gpt()
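For orientation, a self-contained sketch of the tracing flow these tests exercise, on a hypothetical toy module rather than a HuggingFace model (the module and shapes are made up; the ColoTracer/meta_args/GraphModule calls mirror trace_model_and_compare_output below):

import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.fx import ColoTracer


class TinyMLP(nn.Module):    # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyMLP()
tracer = ColoTracer()
# placeholder names map to meta tensors, just like the kwargs produced by data_gen()
graph = tracer.trace(root=model, meta_args={'x': torch.empty(2, 16, device='meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

out = gm(torch.rand(2, 16))    # the recompiled GraphModule runs on real tensors
assert out.shape == (2, 8)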
diff --git a/tests/test_fx/test_tracer/test_hf_model/utils.py b/tests/test_fx/test_tracer/test_hf_model/utils.py
new file mode 100644
index 000000000..cd5e54039
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_hf_model/utils.py
@@ -0,0 +1,33 @@
+from numpy import isin
+import torch
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+from torch.utils._pytree import tree_flatten
+
+
+def trace_model_and_compare_output(model, data_gen):
+    tracer = ColoTracer()
+    # make sure that the model is traceable
+    try:
+        kwargs = data_gen()
+        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        graph = tracer.trace(root=model, meta_args=meta_args)
+    except Exception as e:
+        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    # check output
+    inputs = data_gen()
+
+    # must turn on eval mode to ensure the output is consistent
+    gm.eval()
+    model.eval()
+
+    # run forward
+    non_fx_out = model(**inputs)
+    fx_out = gm(**inputs)
+
+    for k in non_fx_out.keys():
+        if torch.is_tensor(fx_out[k]):
+            assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
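The torch.is_tensor guard is needed because HuggingFace outputs are dict-like and may contain non-tensor entries (for example, GPT-2 returns past_key_values as a nested tuple of tensors). If nested containers ever need to be checked as well, the tree_flatten import above could back a stricter helper along these lines (a hypothetical sketch, not used by the tests):

import torch
from torch.utils._pytree import tree_flatten


def assert_tensor_leaves_equal(fx_out, ref_out):
    # flatten nested tuples/lists/dicts and compare tensor leaves pairwise
    fx_leaves, _ = tree_flatten(fx_out)
    ref_leaves, _ = tree_flatten(ref_out)
    assert len(fx_leaves) == len(ref_leaves)
    for fx_leaf, ref_leaf in zip(fx_leaves, ref_leaves):
        if torch.is_tensor(fx_leaf) and torch.is_tensor(ref_leaf):
            assert torch.equal(fx_leaf, ref_leaf)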