mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] remove GLOBAL_CUDA_MEM_INFO (#2090)
parent 25abae6d7f
commit 28e55c2530
@@ -1,6 +1,8 @@
-from colossalai.context.singleton_meta import SingletonMeta
+from typing import Optional, Tuple
+
 import torch
-from typing import Tuple, Optional
+
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.logging import DistributedLogger
 
 
@@ -106,15 +108,4 @@ class ModelDataTracer(metaclass=SingletonMeta):
         return self._get_mem_usage()
 
 
-class CudaMemInfo(metaclass=SingletonMeta):
-
-    def __init__(self) -> None:
-        self.model_data_list = []
-        self.non_model_data_list = []
-        self.unreleased_grad_flag = {}
-        self.unreleased_grad_volume = 0
-
-
 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
-
-GLOBAL_CUDA_MEM_INFO = CudaMemInfo()
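With the `CudaMemInfo` singleton and `GLOBAL_CUDA_MEM_INFO` removed, gradient bookkeeping moves into a per-tracer `GradMemStats` object, introduced further down in this diff. A minimal sketch, not part of the commit, showing the practical difference: instances no longer share state the way a module-level singleton did.

# Sketch only: per-instance stats replace the module-level singleton.
class GradMemStats:

    def __init__(self) -> None:
        self.unreleased_grad_flag = {}     # param -> gradient still resident on CUDA?
        self.unreleased_grad_volume = 0    # bytes of gradients not yet released

    def clear(self):
        self.unreleased_grad_flag.clear()
        self.unreleased_grad_volume = 0

stats_a, stats_b = GradMemStats(), GradMemStats()
stats_a.unreleased_grad_volume += 1024
assert stats_b.unreleased_grad_volume == 0    # independent trackers, unlike the old global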
@@ -1,8 +1,7 @@
 import torch.nn
 
 from colossalai.gemini.memory_tracer import MemStats
-from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
-from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
+from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 
@@ -25,9 +24,10 @@ class RuntimeMemTracer():
         super().__init__()
         self.module = module
         self.dtype = dtype
+        self._gradstat = GradMemStats()
         self._memstats = MemStats()
-        self.param_op_hook = ParamMemTracerHook(self._memstats)
-        self.grad_hook = GradMemTracerHook(module)
+        self.param_op_hook = ParamMemTracerHook(self._memstats, self._gradstat)
+        self.grad_hook = GradMemTracerHook(self._gradstat)
         self.cpu_param_data_dict = {}
 
         for p in module.parameters():
@@ -58,7 +58,7 @@ class RuntimeMemTracer():
     def _pre_forward(self):
         self._clear_cuda_mem_info()
         self._backup_params()
-        self.grad_hook.register_grad_hook()
+        self.grad_hook.register_grad_hook(self.module)
         self.param_op_hook.mem_monitor.start()
 
     def forward(self, *args, **kwargs):
@@ -78,17 +78,12 @@ class RuntimeMemTracer():
         cuda_volume = self.param_op_hook.mem_monitor.finish()
         self._memstats.append_model_data('cuda', cuda_volume)
         self._memstats.append_non_model_data('cuda')
-        # last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
-        # GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
         self.grad_hook.remove_grad_hook()
         self._restore_params()
 
     def _clear_cuda_mem_info(self):
-        # GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
-        # GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
         self._memstats.clear()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
+        self._gradstat.clear()
 
     def _cast_buffers_to_cuda_dtype(self):
         for buffer in self.module.buffers():
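Taken together, `_pre_forward` and the post-forward path above bracket each traced run: clear the per-pass stats, register the gradient hooks, start the memory monitor, then on exit record the measured CUDA volume into `MemStats`, drop the hooks, and restore parameters. A self-contained sketch of that bracket pattern, with a stand-in monitor instead of `SyncCudaMemoryMonitor` (the ColossalAI classes themselves are not reproduced here):

import torch

class _PeakMonitorSketch:
    # Stand-in for SyncCudaMemoryMonitor: peak CUDA allocation between start() and finish().
    def start(self):
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    def finish(self) -> int:
        return torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0

def traced_forward(module, monitor, volume_log, *args, **kwargs):
    monitor.start()                       # _pre_forward analogue
    out = module(*args, **kwargs)
    volume_log.append(monitor.finish())   # post-forward analogue: record the observed peak
    return out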
@@ -6,7 +6,6 @@ from typing import List
 import torch
 
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.gemini.tensor_utils import alloc_storage, free_storage
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 
@@ -16,23 +15,34 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
+class GradMemStats():
+
+    def __init__(self) -> None:
+        self.unreleased_grad_flag = {}
+        self.unreleased_grad_volume = 0
+
+    def clear(self):
+        self.unreleased_grad_flag.clear()
+        self.unreleased_grad_volume = 0
+
+
 class GradMemTracerHook():
 
-    def __init__(self, module: torch.nn.Module):
-        self.module = module
+    def __init__(self, grad_stats: GradMemStats):
         self.grad_hook_list = []
+        self._grad_stats = grad_stats
 
     def grad_handle(self, p, grad):
-        assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]
+        assert self._grad_stats.unreleased_grad_flag[p]
         free_storage(grad)
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size()
-        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+        self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
+        self._grad_stats.unreleased_grad_flag[p] = False
 
-    def register_grad_hook(self):
-        for p in self.module.parameters():
+    def register_grad_hook(self, module: torch.nn.Module):
+        for p in module.parameters():
             if p.requires_grad:
                 self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
-                GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+                self._grad_stats.unreleased_grad_flag[p] = False
 
     def remove_grad_hook(self):
         for hook in self.grad_hook_list:
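The hook wiring above is plain PyTorch: `functools.partial` binds the owning parameter, and `Tensor.register_hook` fires the callback when that parameter's gradient is produced. A runnable sketch of the same mechanism, with `free_storage` replaced by byte accounting only so it runs on CPU:

from functools import partial

import torch

# stats mimics GradMemStats.unreleased_grad_volume from the diff above.
stats = {"unreleased_grad_volume": 0}

def grad_handle_sketch(p, grad):
    # Analogue of GradMemTracerHook.grad_handle: release the bytes counted for p's grad.
    stats["unreleased_grad_volume"] -= grad.numel() * grad.element_size()

p = torch.nn.Parameter(torch.randn(8))
stats["unreleased_grad_volume"] += p.numel() * p.element_size()    # as sample_model_data would
handle = p.register_hook(partial(grad_handle_sketch, p))           # bind the owning parameter

(p * 2.0).sum().backward()    # hook fires when p's gradient is produced
handle.remove()               # analogue of remove_grad_hook
assert stats["unreleased_grad_volume"] == 0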
@@ -41,10 +51,11 @@ class GradMemTracerHook():
 
 class ParamMemTracerHook(ColoParamOpHook):
 
-    def __init__(self, memstats) -> None:
+    def __init__(self, memstats, gradstats: GradMemStats) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self._memstats = memstats
+        self._grad_stats = gradstats
         self.mem_monitor = SyncCudaMemoryMonitor()
 
     def _free_cuda_params(self, params):
@@ -67,24 +78,21 @@ class ParamMemTracerHook(ColoParamOpHook):
                 alloc_storage(p.data)
 
     def sample_model_data(self, params):
-        data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume
+        data_volume = self._grad_stats.unreleased_grad_volume
         for p in params:
            cur_model_data_volume = p.data.numel() * p.data.element_size()
            data_volume += cur_model_data_volume
            if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
                # add param.grad, actually param.grad is None in this time
                data_volume += cur_model_data_volume
-                if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
-                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
-                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
-        # GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
+                if not self._grad_stats.unreleased_grad_flag[p]:
+                    self._grad_stats.unreleased_grad_volume += cur_model_data_volume
+                    self._grad_stats.unreleased_grad_flag[p] = True
         self._memstats.append_model_data('cuda', data_volume)
 
     def pre_op(self, params):
         cuda_volume = self.mem_monitor.finish()
         self._memstats.append_model_data('cuda', cuda_volume)
-        # if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
-        #     GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
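The accounting in `sample_model_data` is byte arithmetic: each parameter contributes `numel() * element_size()`, and during backward a parameter that requires grad is counted twice, since its gradient is expected to occupy the same number of bytes. A small worked sketch under the assumption of float32 parameters:

import torch

# Worked example of the byte accounting above (assumes float32 and the backward phase).
params = [torch.nn.Parameter(torch.randn(1024, 1024)),    # 4 MiB of parameter data
          torch.nn.Parameter(torch.randn(1024))]          # 4 KiB of parameter data

data_volume = 0
for p in params:
    nbytes = p.data.numel() * p.data.element_size()
    data_volume += nbytes        # parameter data
    if p.requires_grad:
        data_volume += nbytes    # gradient expected to occupy the same number of bytes

assert data_volume == 2 * (1024 * 1024 + 1024) * 4    # data + grad, 4 bytes per element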