from async_memtracer import AsyncMemoryMonitor
import torch

if __name__ == '__main__':
    async_mem_monitor = AsyncMemoryMonitor()
    input = torch.randn(2, 20).cuda()
    OP1 = torch.nn.Linear(20, 30).cuda()
    OP2 = torch.nn.Linear(30, 40).cuda()

    async_mem_monitor.start()
    output = OP1(input)
    async_mem_monitor.finish()
    async_mem_monitor.start()
    output = OP2(output)
    async_mem_monitor.finish()
    async_mem_monitor.save('log.pkl')