diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index a06876310..d521fe212 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -73,10 +73,15 @@ class MemStatsCollector: # deprecated def record_model_data_volume(self) -> None: - """Sampling model data statistics. + """ + Sampling model data statistics. """ if self._start_flag and not self.use_outside_memstats: - raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume") + # The following code work for ZeroInitContext, which is deprecated in v0.1.12 + cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + self._memstats.append_model_data('cuda', cuda_mem) + self._memstats.append_model_data('cpu', cpu_mem) def sample_overall_data(self) -> None: """