[Gemini] remove GLOBAL_CUDA_MEM_INFO (#2090)

Ref: pull/2091/head^2
Author: Jiarui Fang, 2022-12-06 22:10:47 +08:00 (committed by GitHub)
Parent: 25abae6d7f
Commit: 28e55c2530
3 changed files with 36 additions and 42 deletions

colossalai/gemini/memory_tracer/model_data_memtracer.py

@@ -1,6 +1,8 @@
-from colossalai.context.singleton_meta import SingletonMeta
+from typing import Optional, Tuple
 import torch
-from typing import Tuple, Optional
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import DistributedLogger
@@ -106,15 +108,4 @@ class ModelDataTracer(metaclass=SingletonMeta):
         return self._get_mem_usage()
 
 
-class CudaMemInfo(metaclass=SingletonMeta):
-
-    def __init__(self) -> None:
-        self.model_data_list = []
-        self.non_model_data_list = []
-        self.unreleased_grad_flag = {}
-        self.unreleased_grad_volume = 0
-
-
 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
-GLOBAL_CUDA_MEM_INFO = CudaMemInfo()
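The deleted CudaMemInfo above was a process-wide singleton (note the SingletonMeta metaclass), so every tracer in the process read and reset the same counters. The sketch below is not ColossalAI code; the class names are invented purely to illustrate why the per-instance GradMemStats introduced later in this commit is the safer shape:

# Invented names, illustration only: a module-level singleton means two
# independent tracers see each other's bookkeeping.
class _SingletonStats:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.unreleased_grad_volume = 0
        return cls._instance


class _PerTracerStats:
    def __init__(self) -> None:
        self.unreleased_grad_volume = 0


a, b = _SingletonStats(), _SingletonStats()
a.unreleased_grad_volume += 100
print(b.unreleased_grad_volume)    # 100 -- state leaks between "tracers"

c, d = _PerTracerStats(), _PerTracerStats()
c.unreleased_grad_volume += 100
print(d.unreleased_grad_volume)    # 0 -- isolated, the behaviour after this commit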

colossalai/gemini/memory_tracer/runtime_mem_tracer.py

@@ -1,8 +1,7 @@
 import torch.nn
 
 from colossalai.gemini.memory_tracer import MemStats
-from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
-from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
+from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
@@ -25,9 +24,10 @@ class RuntimeMemTracer():
         super().__init__()
         self.module = module
         self.dtype = dtype
+        self._gradstat = GradMemStats()
         self._memstats = MemStats()
-        self.param_op_hook = ParamMemTracerHook(self._memstats)
-        self.grad_hook = GradMemTracerHook(module)
+        self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
+        self.grad_hook = GradMemTracerHook(self._gradstat)
         self.cpu_param_data_dict = {}
         for p in module.parameters():
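Because the stats objects are now passed in through constructors rather than fetched from a global, the hooks can be built and reset in isolation. A hedged sketch using only the constructors visible in this diff; it assumes the import path introduced in the hunk above and that MemStats() can be constructed with no arguments, as the tracer does here:

from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook

gradstat = GradMemStats()                              # unreleased-gradient bookkeeping
memstats = MemStats()                                  # model / non-model data samples
param_hook = ParamMemTracerHook(memstats, gradstat)    # shares the same GradMemStats
grad_hook = GradMemTracerHook(gradstat)                # no longer holds the module itself

gradstat.clear()                                       # resets only this pair of hooks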
@@ -58,7 +58,7 @@ class RuntimeMemTracer():
     def _pre_forward(self):
         self._clear_cuda_mem_info()
         self._backup_params()
-        self.grad_hook.register_grad_hook()
+        self.grad_hook.register_grad_hook(self.module)
         self.param_op_hook.mem_monitor.start()
 
     def forward(self, *args, **kwargs):
@@ -78,17 +78,12 @@ class RuntimeMemTracer():
         cuda_volume = self.param_op_hook.mem_monitor.finish()
         self._memstats.append_model_data('cuda', cuda_volume)
         self._memstats.append_non_model_data('cuda')
-        # last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
-        # GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
         self.grad_hook.remove_grad_hook()
         self._restore_params()
 
     def _clear_cuda_mem_info(self):
-        # GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
-        # GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
         self._memstats.clear()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
+        self._gradstat.clear()
 
     def _cast_buffers_to_cuda_dtype(self):
         for buffer in self.module.buffers():
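Taken together, the tracer is used as before; only where the statistics live has changed. A usage sketch under stated assumptions: the constructor signature RuntimeMemTracer(module, dtype) is inferred from the attributes set in the hunk above, the import path is assumed, and a CUDA device is required.

import torch
import torch.nn as nn

from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer  # assumed path

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).cuda()
tracer = RuntimeMemTracer(model, dtype=torch.half)   # assumed signature

data = torch.randn(16, 64, device='cuda')
out = tracer.forward(data)     # _pre_forward registers the grad hooks and starts the monitor
out.sum().backward()           # grad_handle frees each grad and updates this tracer's GradMemStats
stats = tracer._memstats       # per-instance MemStats; no GLOBAL_CUDA_MEM_INFO left to reset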

colossalai/gemini/ophooks/runtime_mem_tracer_hook.py

@@ -6,7 +6,6 @@ from typing import List
 import torch
 
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
 from colossalai.tensor.param_op_hook import ColoParamOpHook
@@ -16,23 +15,34 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
+class GradMemStats():
+
+    def __init__(self) -> None:
+        self.unreleased_grad_flag = {}
+        self.unreleased_grad_volume = 0
+
+    def clear(self):
+        self.unreleased_grad_flag.clear()
+        self.unreleased_grad_volume = 0
+
+
 class GradMemTracerHook():
 
-    def __init__(self, module: torch.nn.Module):
-        self.module = module
+    def __init__(self, grad_stats: GradMemStats):
         self.grad_hook_list = []
+        self._grad_stats = grad_stats
 
     def grad_handle(self, p, grad):
-        assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]
+        assert self._grad_stats.unreleased_grad_flag[p]
         free_storage(grad)
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+        self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
+        self._grad_stats.unreleased_grad_flag[p] = False
 
-    def register_grad_hook(self):
-        for p in self.module.parameters():
+    def register_grad_hook(self, module: torch.nn.Module):
+        for p in module.parameters():
             if p.requires_grad:
                 self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
-                GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+                self._grad_stats.unreleased_grad_flag[p] = False
 
     def remove_grad_hook(self):
         for hook in self.grad_hook_list:
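The registration pattern above is standard PyTorch: Tensor.register_hook on each parameter, with functools.partial binding the parameter to the callback, and the returned handles kept so remove_grad_hook can undo the registration. A standalone sketch of just that pattern, independent of the ColossalAI classes:

from functools import partial

import torch
import torch.nn as nn

def grad_handle(p, grad):
    # fires once p's gradient has been computed; report its size in bytes
    print(f"grad ready for param {tuple(p.shape)}: {grad.numel() * grad.element_size()} bytes")

model = nn.Linear(4, 2)
handles = [p.register_hook(partial(grad_handle, p)) for p in model.parameters() if p.requires_grad]

model(torch.randn(3, 4)).sum().backward()   # triggers one callback per parameter

for h in handles:                           # mirrors remove_grad_hook()
    h.remove()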
@@ -41,10 +51,11 @@ class GradMemTracerHook():
 class ParamMemTracerHook(ColoParamOpHook):
 
-    def __init__(self, memstats) -> None:
+    def __init__(self, memstats, gradstats: GradMemStats) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self._memstats = memstats
+        self._grad_stats = gradstats
         self.mem_monitor = SyncCudaMemoryMonitor()
 
     def _free_cuda_params(self, params):
@@ -67,24 +78,21 @@ class ParamMemTracerHook(ColoParamOpHook):
                 alloc_storage(p.data)
 
     def sample_model_data(self, params):
-        data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume
+        data_volume = self._grad_stats.unreleased_grad_volume
         for p in params:
             cur_model_data_volume = p.data.numel() * p.data.element_size()
             data_volume += cur_model_data_volume
             if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
                 # add param.grad, actually param.grad is None in this time
                 data_volume += cur_model_data_volume
-                if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
-                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
-                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
-        # GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
+                if not self._grad_stats.unreleased_grad_flag[p]:
+                    self._grad_stats.unreleased_grad_volume += cur_model_data_volume
+                    self._grad_stats.unreleased_grad_flag[p] = True
         self._memstats.append_model_data('cuda', data_volume)
 
     def pre_op(self, params):
         cuda_volume = self.mem_monitor.finish()
         self._memstats.append_model_data('cuda', cuda_volume)
-        # if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
-        #     GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
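For a concrete sense of the byte accounting in sample_model_data, a worked example with a single fp16 parameter of shape (1024, 1024); plain arithmetic, no ColossalAI imports needed:

numel = 1024 * 1024
element_size = 2                        # bytes per fp16 element
param_bytes = numel * element_size      # 2 MiB held by p.data

forward_volume = param_bytes            # forward pass: only the parameter data is sampled

# backward pass: the same volume is counted again for the gradient about to be
# produced, and that amount is also added to unreleased_grad_volume until
# grad_handle frees the grad and subtracts it back out
backward_volume = param_bytes * 2

print(param_bytes, forward_volume, backward_volume)   # 2097152 2097152 4194304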