from colossalai.registry import HOOKS from torch import Tensor from colossalai.trainer.hooks import BaseHook from colossalai.gemini.memory_tracer import AsyncMemoryMonitor @HOOKS.register_module class MemTraceHook(BaseHook): """Save memory stats and pass it to states This hook is used to record memory usage info, and pass to trainer.states You can use it as other trainer hook and fetch data from trainer.states['metrics][mode] """ def __init__( self, priority: int = 0, ) -> None: super().__init__(priority=priority) self._memory_monitor = AsyncMemoryMonitor() def after_hook_is_attached(self, trainer): # Initialize the data trainer.states['metrics']['train'] = self._memory_monitor.state_dict trainer.states['metrics']['test'] = self._memory_monitor.state_dict def before_train_iter(self, trainer): self._memory_monitor.start() return super().before_train_iter(trainer) def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): self._memory_monitor.finish() trainer.states['metrics']['train'] = self._memory_monitor.state_dict trainer.states['metrics']['test'] = self._memory_monitor.state_dict return super().after_train_iter(trainer, output, label, loss) def before_test_iter(self, trainer): self._memory_monitor.start() return super().before_test(trainer) def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): self._memory_monitor.finish() trainer.states['metrics']['train'] = self._memory_monitor.state_dict trainer.states['metrics']['test'] = self._memory_monitor.state_dict return super().after_test_iter(trainer, output, label, loss)