ColossalAI/colossalai/utils/memory_tracer/test_memstats_collector.py

44 lines
1.0 KiB
Python

from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
import torch
def test_mem_collector():
collector = MemStatsCollector()
collector.start_collection()
a = torch.randn(10).cuda()
# sampling at time 0
collector.sample_memstats()
m_a = torch.randn(10).cuda()
GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
b = torch.randn(10).cuda()
# sampling at time 1
collector.sample_memstats()
a = b
# sampling at time 2
collector.sample_memstats()
collector.finish_collection()
collector.reset_sampling_cnter()
# do nothing after collection, just advance sampling cnter
collector.sample_memstats()
collector.sample_memstats()
cuda_use, overall_use = collector.fetch_memstats()
print(cuda_use, overall_use)
print(collector._model_data_cuda)
print(collector._overall_cuda)
if __name__ == '__main__':
test_mem_collector()