[Gemini] make RuntimeMemTracer work correctly (#2096)

pull/2097/head
Jiarui Fang 2022-12-07 16:59:59 +08:00 committed by GitHub
parent fa9d1aea71
commit 4b055351b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 10 deletions

View File

@ -35,13 +35,31 @@ class MemStats(object):
else: else:
raise TypeError raise TypeError
def append_non_model_data(self, device_type: str): def last_model_data(self, device_type: str):
if len(self._model_data_cuda_list) == 0:
return None
if device_type == 'cuda':
return self._model_data_cuda_list[-1]
elif device_type == 'cpu':
return self._model_data_cpu_list[-1]
else:
raise TypeError
def append_non_model_data(self, device_type: str, val=None):
if device_type == 'cuda':
if val is None:
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0: if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
return return
if device_type == 'cuda':
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1]) self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
else:
self._non_model_data_cuda_list.append(val)
elif device_type == 'cpu': elif device_type == 'cpu':
if val is None:
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
return
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
else:
self._non_model_data_cuda_list.append(val)
else: else:
raise TypeError raise TypeError

View File

@ -76,8 +76,7 @@ class RuntimeMemTracer():
def _post_backward(self): def _post_backward(self):
cuda_volume = self.param_op_hook.mem_monitor.finish() cuda_volume = self.param_op_hook.mem_monitor.finish()
self._memstats.append_model_data('cuda', cuda_volume) self._memstats.append_non_model_data('cuda', cuda_volume - self._memstats.last_model_data('cuda'))
self._memstats.append_non_model_data('cuda')
self.grad_hook.remove_grad_hook() self.grad_hook.remove_grad_hook()
self._restore_params() self._restore_params()

View File

@ -5,7 +5,7 @@ from typing import List
import torch import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.tensor.param_op_hook import ColoParamOpHook
@ -51,7 +51,7 @@ class GradMemTracerHook():
class ParamMemTracerHook(ColoParamOpHook): class ParamMemTracerHook(ColoParamOpHook):
def __init__(self, memstats, gradstats: GradMemStats) -> None: def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
super().__init__() super().__init__()
self._training_phase = TrainingPhase.FORWARD self._training_phase = TrainingPhase.FORWARD
self._memstats = memstats self._memstats = memstats
@ -92,7 +92,9 @@ class ParamMemTracerHook(ColoParamOpHook):
def pre_op(self, params): def pre_op(self, params):
cuda_volume = self.mem_monitor.finish() cuda_volume = self.mem_monitor.finish()
self._memstats.append_model_data('cuda', cuda_volume) last_model_data_val = self._memstats.last_model_data('cuda')
if last_model_data_val is not None:
self._memstats.append_non_model_data('cuda', cuda_volume - last_model_data_val)
self._allocate_params_on_cuda(params) self._allocate_params_on_cuda(params)
self.sample_model_data(params) self.sample_model_data(params)
self.mem_monitor.start() self.mem_monitor.start()