From 8c66a1d0aaca7ec37f48c1c19dfbb4495e6ec1ca Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 16 Nov 2022 15:55:10 +0800 Subject: [PATCH] [polish] remove useless file _mem_tracer_hook.py (#1963) --- colossalai/trainer/hooks/_mem_tracer_hook.py | 44 -------------------- 1 file changed, 44 deletions(-) delete mode 100644 colossalai/trainer/hooks/_mem_tracer_hook.py diff --git a/colossalai/trainer/hooks/_mem_tracer_hook.py b/colossalai/trainer/hooks/_mem_tracer_hook.py deleted file mode 100644 index 29c5d9b3c..000000000 --- a/colossalai/trainer/hooks/_mem_tracer_hook.py +++ /dev/null @@ -1,44 +0,0 @@ -from colossalai.registry import HOOKS -from torch import Tensor -from colossalai.trainer.hooks import BaseHook -from colossalai.gemini.memory_tracer import AsyncMemoryMonitor - - -@HOOKS.register_module -class MemTraceHook(BaseHook): - """Save memory stats and pass it to states - This hook is used to record memory usage info, and pass to trainer.states - You can use it as other trainer hook and fetch data from trainer.states['metrics][mode] - """ - - def __init__( - self, - priority: int = 0, - ) -> None: - super().__init__(priority=priority) - self._memory_monitor = AsyncMemoryMonitor() - - def after_hook_is_attached(self, trainer): - # Initialize the data - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - - def before_train_iter(self, trainer): - self._memory_monitor.start() - return super().before_train_iter(trainer) - - def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - self._memory_monitor.finish() - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - return super().after_train_iter(trainer, output, label, loss) - - def before_test_iter(self, trainer): - self._memory_monitor.start() - return super().before_test(trainer) - - def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - self._memory_monitor.finish() - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - return super().after_test_iter(trainer, output, label, loss)