diff --git a/colossalai/trainer/hooks/_mem_tracer_hook.py b/colossalai/trainer/hooks/_mem_tracer_hook.py deleted file mode 100644 index 29c5d9b3c..000000000 --- a/colossalai/trainer/hooks/_mem_tracer_hook.py +++ /dev/null @@ -1,44 +0,0 @@ -from colossalai.registry import HOOKS -from torch import Tensor -from colossalai.trainer.hooks import BaseHook -from colossalai.gemini.memory_tracer import AsyncMemoryMonitor - - -@HOOKS.register_module -class MemTraceHook(BaseHook): - """Save memory stats and pass it to states - This hook is used to record memory usage info, and pass to trainer.states - You can use it as other trainer hook and fetch data from trainer.states['metrics][mode] - """ - - def __init__( - self, - priority: int = 0, - ) -> None: - super().__init__(priority=priority) - self._memory_monitor = AsyncMemoryMonitor() - - def after_hook_is_attached(self, trainer): - # Initialize the data - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - - def before_train_iter(self, trainer): - self._memory_monitor.start() - return super().before_train_iter(trainer) - - def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - self._memory_monitor.finish() - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - return super().after_train_iter(trainer, output, label, loss) - - def before_test_iter(self, trainer): - self._memory_monitor.start() - return super().before_test(trainer) - - def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): - self._memory_monitor.finish() - trainer.states['metrics']['train'] = self._memory_monitor.state_dict - trainer.states['metrics']['test'] = self._memory_monitor.state_dict - return super().after_test_iter(trainer, output, label, loss)