2022-11-18 02:53:42 +00:00
|
|
|
from colossalai.gemini.ophooks import register_ophooks_recursively
|
|
|
|
from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook
|
|
|
|
|
|
|
|
# Names exported by `from <this module> import *`: only MemtracerWrapper.
__all__ = ['MemtracerWrapper']
|
|
|
|
|
|
|
|
|
|
|
|
class _Wrapper:
    """Proxy that pairs a model with the ophooks observing it.

    Calls, ``forward`` and ``named_buffers`` are delegated to the wrapped
    model unchanged; ``backward``, ``save_results`` and ``show_mem_stats``
    drive the attached hooks.
    """

    def __init__(self, model, ophook_list):
        """Store the wrapped model and its hooks.

        Args:
            model: Callable object that also exposes ``forward`` and
                ``named_buffers`` (presumably a ``torch.nn.Module`` --
                confirm against the caller).
            ophook_list: Indexable sequence of hook objects providing
                ``post_iter``, ``save_results`` and ``show_mem_stats``.
        """
        self._model = model
        self._ophook_list = ophook_list

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped model with the given arguments and return its result."""
        wrapped = self._model
        return wrapped(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call the wrapped model's ``forward`` directly and return its result."""
        wrapped = self._model
        return wrapped.forward(*args, **kwargs)

    def backward(self, loss):
        """Run ``loss.backward()`` and then call ``post_iter()`` on every hook.

        Args:
            loss: Object exposing a ``backward()`` method.
        """
        loss.backward()
        # Hooks are notified only after the backward call has returned.
        for hook in self._ophook_list:
            hook.post_iter()

    def save_results(self, filename):
        """Forward ``filename`` to ``save_results`` on every hook, in list order."""
        for hook in self._ophook_list:
            hook.save_results(filename)

    def show_mem_stats(self):
        """Call ``show_mem_stats()`` on the first hook only."""
        first_hook = self._ophook_list[0]
        first_hook.show_mem_stats()

    def named_buffers(self):
        """Return whatever the wrapped model's ``named_buffers()`` returns."""
        wrapped = self._model
        return wrapped.named_buffers()
|
|
|
|
|
2022-11-18 02:53:42 +00:00
|
|
|
|
|
|
|
def MemtracerWrapper(model):
    """Attach a memory-tracing ophook to ``model`` and return a wrapper around it.

    A single ``MemTracerOpHook`` is created, registered on ``model`` through
    ``register_ophooks_recursively``, and handed to the returned ``_Wrapper``
    together with the model.

    Args:
        model: The model to instrument.

    Returns:
        A ``_Wrapper`` holding ``model`` and the one-element hook list.
    """
    hooks = [MemTracerOpHook()]
    register_ophooks_recursively(model, hooks)
    return _Wrapper(model, hooks)
|