2022-03-29 04:48:34 +00:00
|
|
|
from colossalai.registry import HOOKS
|
|
|
|
from torch import Tensor
|
|
|
|
from colossalai.trainer.hooks import BaseHook
|
2022-04-19 02:13:08 +00:00
|
|
|
from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
|
2022-04-13 02:50:54 +00:00
|
|
|
|
2022-03-29 04:48:34 +00:00
|
|
|
|
|
|
|
@HOOKS.register_module
class MemTraceHook(BaseHook):
    """Record memory statistics and publish them to the trainer states.

    An ``AsyncMemoryMonitor`` is started before every train/test iteration
    and finished after it. Once an iteration finishes, the monitor's
    ``state_dict`` is stored under ``trainer.states['metrics']`` for both
    the ``'train'`` and the ``'test'`` mode.

    Use it like any other trainer hook and fetch the data from
    ``trainer.states['metrics'][mode]``.

    Args:
        priority (int): Forwarded unchanged to ``BaseHook.__init__``.
            Defaults to 0.
    """

    def __init__(
        self,
        priority: int = 0,
    ) -> None:
        super().__init__(priority=priority)
        # One monitor instance is shared by the train and the test iterations.
        self._memory_monitor = AsyncMemoryMonitor()

    def _publish_stats(self, trainer) -> None:
        """Store the monitor's ``state_dict`` under both metric modes."""
        # NOTE(review): ``state_dict`` is read as an attribute, not called --
        # confirm it is a property on AsyncMemoryMonitor and not a method.
        # It is read once per assignment, 'train' first, then 'test'.
        for mode in ('train', 'test'):
            trainer.states['metrics'][mode] = self._memory_monitor.state_dict

    def after_hook_is_attached(self, trainer):
        """Initialize both metric entries as soon as the hook is attached."""
        self._publish_stats(trainer)

    def before_train_iter(self, trainer):
        """Start the memory monitor, then run the parent's train-iter hook."""
        self._memory_monitor.start()
        return super().before_train_iter(trainer)

    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Finish the memory monitor and publish the stats after a train step."""
        self._memory_monitor.finish()
        self._publish_stats(trainer)
        return super().after_train_iter(trainer, output, label, loss)

    def before_test_iter(self, trainer):
        """Start the memory monitor, then run the parent's test-iter hook."""
        self._memory_monitor.start()
        # Delegate to the per-iteration parent hook. The previous code called
        # ``before_test``, which is the once-per-test-run hook, so the
        # parent's ``before_test_iter`` was never reached from here.
        return super().before_test_iter(trainer)

    def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Finish the memory monitor and publish the stats after a test step."""
        self._memory_monitor.finish()
        self._publish_stats(trainer)
        return super().after_test_iter(trainer, output, label, loss)
|