[polish] remove useless file _mem_tracer_hook.py (#1963)

pull/1964/head
Jiarui Fang 2 years ago committed by GitHub
parent c4739a725a
commit 8c66a1d0aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,44 +0,0 @@
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.trainer.hooks import BaseHook
from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
@HOOKS.register_module
class MemTraceHook(BaseHook):
"""Save memory stats and pass it to states
This hook is used to record memory usage info, and pass to trainer.states
You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
"""
def __init__(
self,
priority: int = 0,
) -> None:
super().__init__(priority=priority)
self._memory_monitor = AsyncMemoryMonitor()
def after_hook_is_attached(self, trainer):
# Initialize the data
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
def before_train_iter(self, trainer):
self._memory_monitor.start()
return super().before_train_iter(trainer)
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
self._memory_monitor.finish()
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
return super().after_train_iter(trainer, output, label, loss)
def before_test_iter(self, trainer):
self._memory_monitor.start()
return super().before_test(trainer)
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
self._memory_monitor.finish()
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
return super().after_test_iter(trainer, output, label, loss)
Loading…
Cancel
Save