ColossalAI/colossalai/gemini/memory_tracer/module_tracer_wrapper.py

40 lines
1.0 KiB
Python
Raw Normal View History

from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook
__all__ = ['MemtracerWrapper']
class _Wrapper():
def __init__(self, model, ophook_list):
self._ophook_list = ophook_list
self._model = model
def __call__(self, *args, **kwargs):
return self._model(*args, **kwargs)
def forward(self, *args, **kwargs):
return self._model.forward(*args, **kwargs)
def backward(self, loss):
loss.backward()
for ophook in self._ophook_list:
ophook.post_iter()
def save_results(self, filename):
for ophook in self._ophook_list:
ophook.save_results(filename)
def show_mem_stats(self):
self._ophook_list[0].show_mem_stats()
def named_buffers(self):
return self._model.named_buffers()
def MemtracerWrapper(model):
ophook_list = [MemTracerOpHook()]
register_ophooks_recursively(model, ophook_list)
engine = _Wrapper(model, ophook_list)
return engine