ColossalAI/tests/test_tensor/test_graph.py

import gc

import pytest
import torch
from torch import nn

from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext
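

# A small feed-forward network whose forward pass creates the
# intermediate activation tensors that this test counts.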
class SimpleNet(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)
        self.proj3 = nn.Linear(4, 4)
        self.proj4 = nn.Linear(4, 4)

    def forward(self, x):
        x = self.proj1(x)
        x = self.proj2(x)
        x = self.proj3(x)
        x = self.proj4(x)
        return x
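

# Depth-first traversal of the captured graph: print start_node, then
# recurse into each of its post_nodes.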
def _visit_graph(start_node):
    if start_node is None:
        return
    start_node.print()
    for node in start_node.post_nodes:
        _visit_graph(node)
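

# Yield every live torch.Tensor currently tracked by the garbage
# collector, skipping objects that raise when inspected.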
def _get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
        except Exception as e:
            print('A trivial exception occurred: {}'.format(e))
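

# Count the live tensors yielded by _get_tensors().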
def _count_tensors():
    return sum(1 for _ in _get_tensors())
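

# Run two inference-mode forward passes, with or without ColossalAI's
# graph tracing, and return the number of tensors still alive afterwards.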
def count_tensors(use_colossal):
    model = SimpleNet()
    model.eval()
    with torch.no_grad():
        if use_colossal:
            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
            graph_ctx = GraphContext()
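            # Trace the first forward pass inside the graph context;
            # the second pass runs outside it.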
            with graph_ctx:
                output = model(colo_input)
            output = model(colo_input)
            ret = _count_tensors()
            _visit_graph(graph_ctx.graph_nodes[0])
            del graph_ctx
            return ret
        else:
            input_t = torch.randn(4)
            output = model(input_t)
            output = model(input_t)
            return _count_tensors()


# FIXME(ver217)
@pytest.mark.skip
def test_check_activation_tensors():
    assert count_tensors(False) == count_tensors(True)


if __name__ == "__main__":
    count_tensors(True)