mirror of https://github.com/hpcaitech/ColossalAI
[fx] added unit test for coloproxy (#1119)
* [fx] added unit test for coloproxy * polish code * polish codepull/1121/head
parent
7d14b473f0
commit
16302a5359
|
@ -19,16 +19,16 @@ class ColoProxy(Proxy):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.meta_tensor = None
|
self._meta_tensor = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def meta_tensor(self):
|
def meta_tensor(self):
|
||||||
return self.meta_tensor
|
return self._meta_tensor
|
||||||
|
|
||||||
@meta_tensor.setter
|
@meta_tensor.setter
|
||||||
def meta_tensor(self, tensor: torch.Tensor):
|
def meta_tensor(self, tensor: torch.Tensor):
|
||||||
assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||||
self.meta_tensor = tensor
|
self._meta_tensor = tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_meta_tensor(self):
|
def has_meta_tensor(self):
|
||||||
|
@ -42,6 +42,19 @@ class ColoProxy(Proxy):
|
||||||
self._assert_has_meta()
|
self._assert_has_meta()
|
||||||
return self.meta_tensor.dtype
|
return self.meta_tensor.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
self._assert_has_meta()
|
||||||
|
return self.meta_tensor.shape
|
||||||
|
|
||||||
|
def dim(self):
|
||||||
|
self._assert_has_meta()
|
||||||
|
return self.meta_tensor.dim()
|
||||||
|
|
||||||
|
def size(self, dim: int = None):
|
||||||
|
self._assert_has_meta()
|
||||||
|
return self.meta_tensor.size(dim=dim)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
self._assert_has_meta()
|
self._assert_has_meta()
|
||||||
return len(self.meta_tensor)
|
return len(self.meta_tensor)
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
import torch
|
||||||
|
from colossalai.fx.proxy import ColoProxy
|
||||||
|
|
||||||
|
|
||||||
|
def test_coloproxy():
|
||||||
|
# create a dummy node only for testing purpose
|
||||||
|
model = torch.nn.Linear(10, 10)
|
||||||
|
gm = torch.fx.symbolic_trace(model)
|
||||||
|
node = list(gm.graph.nodes)[0]
|
||||||
|
|
||||||
|
# create proxy
|
||||||
|
proxy = ColoProxy(node=node)
|
||||||
|
proxy.meta_tensor = torch.empty(4, 2, device='meta')
|
||||||
|
|
||||||
|
assert len(proxy) == 4
|
||||||
|
assert proxy.shape[0] == 4 and proxy.shape[1] == 2
|
||||||
|
assert proxy.dim() == 2
|
||||||
|
assert proxy.dtype == torch.float32
|
||||||
|
assert proxy.size(0) == 4
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_coloproxy()
|
Loading…
Reference in New Issue