mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee
2 years ago
committed by
GitHub
2 changed files with 40 additions and 4 deletions
@ -0,0 +1,23 @@
|
||||
import torch |
||||
from colossalai.fx.proxy import ColoProxy |
||||
|
||||
|
||||
def test_coloproxy(): |
||||
# create a dummy node only for testing purpose |
||||
model = torch.nn.Linear(10, 10) |
||||
gm = torch.fx.symbolic_trace(model) |
||||
node = list(gm.graph.nodes)[0] |
||||
|
||||
# create proxy |
||||
proxy = ColoProxy(node=node) |
||||
proxy.meta_tensor = torch.empty(4, 2, device='meta') |
||||
|
||||
assert len(proxy) == 4 |
||||
assert proxy.shape[0] == 4 and proxy.shape[1] == 2 |
||||
assert proxy.dim() == 2 |
||||
assert proxy.dtype == torch.float32 |
||||
assert proxy.size(0) == 4 |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
test_coloproxy() |
Loading…
Reference in new issue