"""Compare live-tensor counts with and without ColossalAI graph tracing.

``test_check_activation_tensors`` runs the same small MLP on a plain ``torch``
tensor and on a ``ColoTensor`` (with a ``GraphContext`` active), then asserts
that the garbage collector reports the same number of live tensors in both
cases.
"""
import gc

import torch
from torch import nn

from colossalai.tensor import ColoTensor
from colossalai.tensor.graph import GraphContext


class SimpleNet(nn.Module):
    """Four stacked ``nn.Linear`` layers: 4 -> 8 -> 4 -> 4 -> 4."""

    def __init__(self) -> None:
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 4)
        self.proj3 = nn.Linear(4, 4)
        self.proj4 = nn.Linear(4, 4)

    def forward(self, x):
        """Apply the four projections in order, with no activation in between."""
        x = self.proj1(x)
        x = self.proj2(x)
        x = self.proj3(x)
        x = self.proj4(x)
        return x


def _visit_graph(start_node):
    """Print ``start_node`` and, recursively, every node reachable via ``post_nodes``.

    Traversal is depth-first. No visited set is kept, so a node reachable along
    several paths is printed once per path. A ``None`` start node is a no-op.
    """
    if start_node is None:
        return
    start_node.print()
    for node in start_node.post_nodes:
        _visit_graph(node)


def _get_tensors():
    """Yield every object returned by ``gc.get_objects()`` that is a torch tensor.

    ``torch.is_tensor`` is called on arbitrary objects here; any exception it
    raises is reported and that object is skipped, so one odd object cannot
    abort the scan.
    """
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                yield obj
        except Exception as e:
            print('A trivial exception occured: {}'.format(e))


def _count_tensors():
    """Return the number of live tensors currently visible to the garbage collector."""
    # Reclaim unreachable reference cycles first. Otherwise garbage left behind
    # by earlier runs (e.g. graph nodes that reference each other) is still
    # listed by gc.get_objects(). That would make counts from different calls
    # depend on call order and on unrelated earlier tests.
    gc.collect()
    return sum(1 for _ in _get_tensors())


def count_tensors(use_colossal):
    """Run ``SimpleNet`` twice on a random 4-element input and count live tensors.

    Args:
        use_colossal: if true, wrap the input in a ``ColoTensor`` and run a
            forward pass with a ``GraphContext`` active; otherwise use a plain
            ``torch`` tensor.

    Returns:
        The result of ``_count_tensors()``, taken while the model, input and
        last output (and, on the ColossalAI path, the graph context) are still
        referenced.
    """
    model = SimpleNet()
    model.eval()
    with torch.no_grad():
        if use_colossal:
            colo_input = ColoTensor.init_from_torch_tensor(torch.randn(4))
            graph_ctx = GraphContext()
            # NOTE(review): the source file had been flattened to one line, so
            # the original nesting is ambiguous. The second forward pass is
            # assumed to run outside the graph context; confirm against history.
            with graph_ctx:
                output = model(colo_input)
            output = model(colo_input)
            # Count while graph_ctx is still alive, so any tensors its nodes
            # retain are included in the comparison with the plain-torch path.
            ret = _count_tensors()
            _visit_graph(graph_ctx.graph_nodes[0])
            del graph_ctx
            return ret
        else:
            input_t = torch.randn(4)
            output = model(input_t)
            output = model(input_t)
            return _count_tensors()


def test_check_activation_tensors():
    """Graph tracing should not change the number of live tensors."""
    assert count_tensors(False) == count_tensors(True)


if __name__ == "__main__":
    count_tensors(True)