Browse Source

[fx] added unit test for coloproxy (#1119)

* [fx] added unit test for coloproxy

* polish code

* polish code
pull/1121/head
Frank Lee 2 years ago committed by GitHub
parent
commit
16302a5359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 21
      colossalai/fx/proxy.py
  2. 23
      tests/test_fx/test_coloproxy.py

21
colossalai/fx/proxy.py

@ -19,16 +19,16 @@ class ColoProxy(Proxy):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.meta_tensor = None
self._meta_tensor = None
@property
def meta_tensor(self):
return self.meta_tensor
return self._meta_tensor
@meta_tensor.setter
def meta_tensor(self, tensor: torch.Tensor):
assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
self.meta_tensor = tensor
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
self._meta_tensor = tensor
@property
def has_meta_tensor(self):
@ -42,6 +42,19 @@ class ColoProxy(Proxy):
self._assert_has_meta()
return self.meta_tensor.dtype
@property
def shape(self):
self._assert_has_meta()
return self.meta_tensor.shape
def dim(self):
self._assert_has_meta()
return self.meta_tensor.dim()
def size(self, dim: int = None):
self._assert_has_meta()
return self.meta_tensor.size(dim=dim)
def __len__(self):
self._assert_has_meta()
return len(self.meta_tensor)

23
tests/test_fx/test_coloproxy.py

@ -0,0 +1,23 @@
import torch
from colossalai.fx.proxy import ColoProxy
def test_coloproxy():
# create a dummy node only for testing purpose
model = torch.nn.Linear(10, 10)
gm = torch.fx.symbolic_trace(model)
node = list(gm.graph.nodes)[0]
# create proxy
proxy = ColoProxy(node=node)
proxy.meta_tensor = torch.empty(4, 2, device='meta')
assert len(proxy) == 4
assert proxy.shape[0] == 4 and proxy.shape[1] == 2
assert proxy.dim() == 2
assert proxy.dtype == torch.float32
assert proxy.size(0) == 4
if __name__ == '__main__':
test_coloproxy()
Loading…
Cancel
Save