diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index 5224deaeb..a6c02927c 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -19,16 +19,16 @@ class ColoProxy(Proxy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.meta_tensor = None + self._meta_tensor = None @property def meta_tensor(self): - return self.meta_tensor + return self._meta_tensor @meta_tensor.setter def meta_tensor(self, tensor: torch.Tensor): - assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor' - self.meta_tensor = tensor + assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor' + self._meta_tensor = tensor @property def has_meta_tensor(self): @@ -42,6 +42,19 @@ class ColoProxy(Proxy): self._assert_has_meta() return self.meta_tensor.dtype + @property + def shape(self): + self._assert_has_meta() + return self.meta_tensor.shape + + def dim(self): + self._assert_has_meta() + return self.meta_tensor.dim() + + def size(self, dim: int = None): + self._assert_has_meta() + return self.meta_tensor.size(dim=dim) + def __len__(self): self._assert_has_meta() return len(self.meta_tensor) diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py new file mode 100644 index 000000000..a1c75d168 --- /dev/null +++ b/tests/test_fx/test_coloproxy.py @@ -0,0 +1,23 @@ +import torch +from colossalai.fx.proxy import ColoProxy + + +def test_coloproxy(): + # create a dummy node only for testing purpose + model = torch.nn.Linear(10, 10) + gm = torch.fx.symbolic_trace(model) + node = list(gm.graph.nodes)[0] + + # create proxy + proxy = ColoProxy(node=node) + proxy.meta_tensor = torch.empty(4, 2, device='meta') + + assert len(proxy) == 4 + assert proxy.shape[0] == 4 and proxy.shape[1] == 2 + assert proxy.dim() == 2 + assert proxy.dtype == torch.float32 + assert proxy.size(0) == 4 + + +if __name__ == '__main__': + test_coloproxy() \ No newline at end of file