mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] fix grad unreleased issue and param recovery issue (#2052)
parent
edf4cd46c5
commit
38ea4ba1bd
|
@ -106,4 +106,15 @@ class ModelDataTracer(metaclass=SingletonMeta):
|
|||
return self._get_mem_usage()
|
||||
|
||||
|
||||
class CudaMemInfo(metaclass=SingletonMeta):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.model_data_list = []
|
||||
self.non_model_data_list = []
|
||||
self.unreleased_grad_flag = {}
|
||||
self.unreleased_grad_volume = 0
|
||||
|
||||
|
||||
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
||||
|
||||
GLOBAL_CUDA_MEM_INFO = CudaMemInfo()
|
|
@ -1,11 +1,9 @@
|
|||
import torch.nn
|
||||
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
|
||||
from colossalai.gemini.tensor_utils import free_storage
|
||||
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||
from functools import partial
|
||||
|
||||
|
||||
__all__ = ['ParamTracerWrapper']
|
||||
|
||||
|
@ -15,22 +13,33 @@ class ParamTracerWrapper():
|
|||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.param_op_hook = ParamTracerHook(dtype)
|
||||
self.param_op_hook = ParamTracerHook()
|
||||
self.grad_hook = GradHook(module)
|
||||
self.cpu_param_data_dict = {}
|
||||
|
||||
for p in module.parameters():
|
||||
p.data = p.data.to(dtype)
|
||||
if p.requires_grad:
|
||||
p.register_hook(partial(self.grad_handle))
|
||||
|
||||
self._cast_buffers_to_cuda_dtype()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def grad_handle(self, grad):
|
||||
free_storage(grad)
|
||||
def _save_param_data_on_cpu(self):
|
||||
for p in self.module.parameters():
|
||||
self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
|
||||
self.cpu_param_data_dict[p].copy_(p.data)
|
||||
|
||||
def _restore_param_data(self):
|
||||
for p in self.module.parameters():
|
||||
p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
|
||||
p.data.copy_(self.cpu_param_data_dict[p])
|
||||
self.cpu_param_data_dict.clear()
|
||||
|
||||
def _pre_forward(self):
|
||||
self._clear_cuda_mem_info()
|
||||
self._save_param_data_on_cpu()
|
||||
self.grad_hook.register_grad_hook()
|
||||
self.param_op_hook.mem_monitor.start()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
@ -48,8 +57,16 @@ class ParamTracerWrapper():
|
|||
|
||||
def _post_backward(self):
|
||||
cuda_volume = self.param_op_hook.mem_monitor.finish()
|
||||
last_model_data = self.param_op_hook._model_data_list[-1]
|
||||
self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data)
|
||||
last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
|
||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
|
||||
self.grad_hook.remove_grad_hook()
|
||||
self._restore_param_data()
|
||||
|
||||
def _clear_cuda_mem_info(self):
|
||||
GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
|
||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
|
||||
|
||||
def _cast_buffers_to_cuda_dtype(self):
|
||||
for buffer in self.module.buffers():
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch
|
|||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.gemini.tensor_utils import free_storage, alloc_storage
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -15,42 +16,69 @@ class TrainingPhase(Enum):
|
|||
BACKWARD = 1
|
||||
|
||||
|
||||
class GradHook():
|
||||
def __init__(self, module: torch.nn.Module):
|
||||
self.module = module
|
||||
self.grad_hook_list = []
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]
|
||||
free_storage(grad)
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size()
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
|
||||
|
||||
def register_grad_hook(self):
|
||||
for p in self.module.parameters():
|
||||
if p.requires_grad:
|
||||
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
|
||||
|
||||
def remove_grad_hook(self):
|
||||
for hook in self.grad_hook_list:
|
||||
hook.remove()
|
||||
|
||||
|
||||
class ParamTracerHook(ParamOpHook):
|
||||
|
||||
def __init__(self, dtype: torch.dtype = torch.half) -> None:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||
self._non_model_data_list = []
|
||||
self._model_data_list = []
|
||||
self.dtype = dtype
|
||||
|
||||
def _free_cuda_params(self, params):
|
||||
for p in params:
|
||||
if p.data.device.type == "cpu":
|
||||
raise NotImplementedError("Only free cuda memory")
|
||||
free_storage(p.data)
|
||||
|
||||
def _allocate_params_on_cuda(self, params):
|
||||
for p in params:
|
||||
cur_dev = p.data.device.type
|
||||
if cur_dev == "cpu":
|
||||
# p.data = p.data.to("cuda")
|
||||
p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
|
||||
if p.grad is not None and p.grad.device.type == "cpu":
|
||||
raise NotImplementedError("Only run in forward propagation")
|
||||
p.data = torch.empty(p.data.shape, device="cuda", dtype=p.data.dtype,
|
||||
requires_grad=p.data.requires_grad)
|
||||
elif cur_dev == "cuda":
|
||||
alloc_storage(p.data)
|
||||
|
||||
def sample_model_data(self, params):
|
||||
data_volume = 0
|
||||
data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume
|
||||
for p in params:
|
||||
data_volume += p.data.numel() * p.data.element_size()
|
||||
if self._training_phase == TrainingPhase.BACKWARD:
|
||||
# add param.grad, actually param.grad is None in this time
|
||||
data_volume *= 2
|
||||
self._model_data_list.append(data_volume)
|
||||
cur_model_data_volume = p.data.numel() * p.data.element_size()
|
||||
data_volume += cur_model_data_volume
|
||||
if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
|
||||
# add param.grad, actually param.grad is None in this time
|
||||
data_volume += cur_model_data_volume
|
||||
if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
|
||||
GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
|
||||
|
||||
def pre_op(self, params):
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
if len(self._model_data_list):
|
||||
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
|
||||
if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
|
||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
|
||||
self._allocate_params_on_cuda(params)
|
||||
self.sample_model_data(params)
|
||||
self.mem_monitor.start()
|
||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
@ -35,9 +36,9 @@ def run_param_wrapper_testing():
|
|||
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
|
||||
cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2
|
||||
cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024 ** 2
|
||||
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
||||
# print(model.param_op_hook._non_model_data_list)
|
||||
# print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
|
||||
|
||||
del model
|
||||
|
||||
|
|
Loading…
Reference in New Issue