from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
import torch

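
# Smoke test for MemStatsCollector: allocate a few CUDA tensors, register one of
# them (m_a) with GLOBAL_MODEL_DATA_TRACER as model data, and sample memory
# statistics at several points in time.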
def test_mem_collector():
    collector = MemStatsCollector()

    collector.start_collection()

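    # a is a plain CUDA tensor that is never registered as model data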
    a = torch.randn(10).cuda()

    # sampling at time 0
    collector.sample_memstats()

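    # m_a is registered with the global model data tracer; b stays unregistered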
    m_a = torch.randn(10).cuda()
    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
    b = torch.randn(10).cuda()

    # sampling at time 1
    collector.sample_memstats()

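    # rebind a to b; the tensor originally referenced by a becomes unreferenced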
    a = b

    # sampling at time 2
    collector.sample_memstats()

    collector.finish_collection()
    collector.reset_sampling_cnter()

    # sampling after finish_collection() does nothing except advance the sampling counter
    collector.sample_memstats()
    collector.sample_memstats()

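    # read back the collected statistics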
    cuda_use, overall_use = collector.fetch_memstats()
    print(cuda_use, overall_use)

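    # inspect the collector's internal records directly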
    print(collector._model_data_cuda)
    print(collector._overall_cuda)

if __name__ == '__main__':
    test_mem_collector()