mirror of https://github.com/hpcaitech/ColossalAI
38 lines
775 B
Python
38 lines
775 B
Python
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
|
import torch
|
|
|
|
|
|
def test_mem_collector():
    """Exercise MemStatsCollector's collect / finish / replay lifecycle.

    Three ``sample_memstats()`` calls are made while collection is active,
    with CUDA tensors allocated (and one reference dropped) between them.
    After ``finish_collection()`` and ``reset_sampling_cnter()``, two more
    ``sample_memstats()`` calls are made; these should only advance the
    sampling counter, so the recorded overall CUDA stats must stay unchanged.

    Requires a CUDA device: every tensor is moved with ``.cuda()``.
    """
    collector = MemStatsCollector()
    collector.start_collection()

    a = torch.randn(10).cuda()

    # sampling at time 0
    collector.sample_memstats()

    # Extra allocations between samples. m_a is never read again, but the
    # binding keeps its tensor alive so it is still allocated at later samples.
    m_a = torch.randn(10).cuda()
    b = torch.randn(10).cuda()

    # sampling at time 1
    collector.sample_memstats()

    # Rebinding `a` drops the last reference to the first tensor it held.
    a = b

    # sampling at time 2
    collector.sample_memstats()

    collector.finish_collection()
    collector.reset_sampling_cnter()

    # Copy what was recorded during the collection phase so later calls
    # cannot alias it.
    # NOTE(review): assumes overall_mem_stats() returns an iterable with one
    # entry per sample taken while collection was active — confirm.
    collected_stats = list(collector.overall_mem_stats('cuda'))
    assert len(collected_stats) == 3, (
        "expected one overall CUDA sample per sample_memstats() call "
        "made while collection was active")

    # do nothing after collection, just advance sampling cnter
    collector.sample_memstats()
    collector.sample_memstats()

    # Sampling after finish_collection() must not record anything new.
    assert list(collector.overall_mem_stats('cuda')) == collected_stats, (
        "sample_memstats() recorded stats after finish_collection()")

    print(collector.overall_mem_stats('cuda'))
|
|
|
|
|
|
if __name__ == '__main__':
    # Running this module directly executes the test without a test runner.
    test_mem_collector()
|