mirror of https://github.com/hpcaitech/ColossalAI
[fx] added unit test for coloproxy (#1119)
* [fx] added unit test for coloproxy * polish code * polish codepull/1121/head
parent
7d14b473f0
commit
16302a5359
|
@ -19,16 +19,16 @@ class ColoProxy(Proxy):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.meta_tensor = None
|
||||
self._meta_tensor = None
|
||||
|
||||
@property
|
||||
def meta_tensor(self):
|
||||
return self.meta_tensor
|
||||
return self._meta_tensor
|
||||
|
||||
@meta_tensor.setter
|
||||
def meta_tensor(self, tensor: torch.Tensor):
|
||||
assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||
self.meta_tensor = tensor
|
||||
assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
|
||||
self._meta_tensor = tensor
|
||||
|
||||
@property
|
||||
def has_meta_tensor(self):
|
||||
|
@ -42,6 +42,19 @@ class ColoProxy(Proxy):
|
|||
self._assert_has_meta()
|
||||
return self.meta_tensor.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.shape
|
||||
|
||||
def dim(self):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.dim()
|
||||
|
||||
def size(self, dim: int = None):
|
||||
self._assert_has_meta()
|
||||
return self.meta_tensor.size(dim=dim)
|
||||
|
||||
def __len__(self):
|
||||
self._assert_has_meta()
|
||||
return len(self.meta_tensor)
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
import torch
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
|
||||
|
||||
def test_coloproxy():
|
||||
# create a dummy node only for testing purpose
|
||||
model = torch.nn.Linear(10, 10)
|
||||
gm = torch.fx.symbolic_trace(model)
|
||||
node = list(gm.graph.nodes)[0]
|
||||
|
||||
# create proxy
|
||||
proxy = ColoProxy(node=node)
|
||||
proxy.meta_tensor = torch.empty(4, 2, device='meta')
|
||||
|
||||
assert len(proxy) == 4
|
||||
assert proxy.shape[0] == 4 and proxy.shape[1] == 2
|
||||
assert proxy.dim() == 2
|
||||
assert proxy.dtype == torch.float32
|
||||
assert proxy.size(0) == 4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_coloproxy()
|
Loading…
Reference in New Issue