
[Gemini] free and allocate cuda memory by tensor.storage, add grad hook (#2040)

pull/2047/head
Authored by Zihao 2 years ago, committed by GitHub
commit 6a9158f1fa
  1. colossalai/gemini/memory_tracer/param_tracer_wrapper.py (12 changes)
  2. colossalai/gemini/ophooks/param_trace_hook.py (30 changes)
  3. colossalai/gemini/tensor_utils.py (14 changes)
  4. tests/test_gemini/test_param_tracer.py (2 changes)

colossalai/gemini/memory_tracer/param_tracer_wrapper.py (12 changes)

@@ -1,9 +1,11 @@
 import torch.nn

 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
+from colossalai.gemini.tensor_utils import free_storage
 from colossalai.nn.parallel.data_parallel import _cast_float
+from functools import partial

 __all__ = ['ParamTracerWrapper']
@@ -13,17 +15,21 @@ class ParamTracerWrapper():
         super().__init__()
         self.module = module
         self.dtype = dtype
-        self.param_op_hook = ParamTracerHook()
+        self.param_op_hook = ParamTracerHook(dtype)

         for p in module.parameters():
             assert isinstance(p, ColoParameter)
             p.data = p.data.to(dtype)
+            if p.requires_grad:
+                p.register_hook(partial(self.grad_handle))

         self._cast_buffers_to_cuda_dtype()

     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)

+    def grad_handle(self, grad):
+        free_storage(grad)

     def _pre_forward(self):
         self.param_op_hook.mem_monitor.start()
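
Note on the grad hook added above: register_hook makes autograd invoke grad_handle the moment a parameter's gradient is produced, and grad_handle immediately releases that gradient's CUDA storage, since the tracer only needs the memory event, not the gradient values. Below is a minimal standalone sketch of the same pattern (illustration only, not part of this commit; it assumes a CUDA device, and its free_storage is a copy of the helper this commit adds in colossalai/gemini/tensor_utils.py, shown further down):

import torch
import torch.nn as nn

def free_storage(tensor: torch.Tensor) -> None:
    # copy of the helper this commit adds in colossalai/gemini/tensor_utils.py
    if tensor.storage().size() != 0:
        tensor.storage().resize_(0)

def grad_handle(grad: torch.Tensor) -> None:
    # invoked by autograd right after this parameter's gradient is computed
    free_storage(grad)

model = nn.Linear(1024, 1024).cuda()
for p in model.parameters():
    if p.requires_grad:
        p.register_hook(grad_handle)

loss = model(torch.randn(32, 1024, device="cuda")).sum()
loss.backward()
# the .grad tensors still exist, but their CUDA payload was dropped inside the hook
print(model.weight.grad.storage().size())   # expected: 0
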

colossalai/gemini/ophooks/param_trace_hook.py (30 changes)

@@ -7,6 +7,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.gemini.tensor_utils import free_storage, alloc_storage


 class TrainingPhase(Enum):
@@ -16,25 +17,26 @@ class TrainingPhase(Enum):
 class ParamTracerHook(ParamOpHook):

-    def __init__(self) -> None:
+    def __init__(self, dtype: torch.dtype = torch.half) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self.mem_monitor = SyncCudaMemoryMonitor()
         self._non_model_data_list = []
         self._model_data_list = []
+        self.dtype = dtype

-    def _move_params_to_dev(self, params, dev: str) -> int:
-        assert isinstance(dev, str), f"device should be a str not torch.device"
-        comm_volume = 0
+    def _free_cuda_params(self, params):
         for p in params:
-            if p.data.device.type != dev:
-                p.data = p.data.to(dev)
-                comm_volume += p.data.numel() * p.data.element_size()
-            if p.grad is not None:
-                if p.grad.device.type != dev:
-                    p.grad = p.grad.to(dev)
-                    comm_volume += p.grad.numel() * p.grad.element_size()
-        return comm_volume
+            free_storage(p.data)
+
+    def _allocate_params_on_cuda(self, params):
+        for p in params:
+            cur_dev = p.data.device.type
+            if cur_dev == "cpu":
+                # p.data = p.data.to("cuda")
+                p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+            elif cur_dev == "cuda":
+                alloc_storage(p.data)

     def sample_model_data(self, params):
         data_volume = 0
@@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
         cuda_volume = self.mem_monitor.finish()
         if len(self._model_data_list):
             self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
-        self._move_params_to_dev(params, 'cuda')
+        self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()

     def post_op(self, params):
-        self._move_params_to_dev(params, 'cpu')
+        self._free_cuda_params(params)

     def pre_forward(self, params: List[torch.Tensor]) -> None:
         self.pre_op(params)
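
The new pre_op/post_op pair gives the hook an allocate/run/free lifecycle: before a traced operator runs, every involved parameter gets a CUDA payload (fresh random data is enough, because the tracer measures memory footprint rather than numerical results), and after the operator the payload is dropped again so only the tensor metadata stays resident. A simplified standalone sketch of that cycle (illustration only, not part of this commit; the real code plugs these steps into ColossalAI's ParamOpHook machinery, and the module-level functions below are stand-ins for the methods shown in the diff):

import torch

def free_storage(t: torch.Tensor) -> None:
    if t.storage().size() != 0:
        t.storage().resize_(0)

def alloc_storage(t: torch.Tensor) -> None:
    if t.storage().size() == 0:
        t.storage().resize_(t.numel())

def allocate_params_on_cuda(params, dtype=torch.half):
    # mirrors _allocate_params_on_cuda: CPU params get a throwaway CUDA payload,
    # params already on CUDA just get their (previously freed) storage back
    for p in params:
        if p.data.device.type == "cpu":
            p.data = torch.randn(p.data.shape, device="cuda", dtype=dtype)
        elif p.data.device.type == "cuda":
            alloc_storage(p.data)

def free_cuda_params(params):
    # mirrors _free_cuda_params: keep the tensor objects, drop their CUDA payloads
    for p in params:
        free_storage(p.data)

weight = torch.nn.Parameter(torch.randn(1024, 1024))   # fp32 parameter living on CPU
x = torch.randn(32, 1024, device="cuda", dtype=torch.half)

allocate_params_on_cuda([weight])   # pre_op: give the parameter an fp16 CUDA payload
y = x @ weight.data.t()             # the operator being traced
free_cuda_params([weight])          # post_op: release the parameter's CUDA payload
print(torch.cuda.memory_allocated())
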

colossalai/gemini/tensor_utils.py (14 changes)

@@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
 from typing import Union, Tuple


+def is_storage_empty(tensor: torch.Tensor) -> bool:
+    return tensor.storage().size() == 0
+
+
+def free_storage(tensor: torch.Tensor) -> None:
+    if not is_storage_empty(tensor):
+        tensor.storage().resize_(0)
+
+
+def alloc_storage(tensor: torch.Tensor) -> None:
+    if is_storage_empty(tensor):
+        tensor.storage().resize_(tensor.numel())
+
+
 def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
     if isinstance(tensor, StatefulTensor):
         t = tensor.payload
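
The three helpers above are the core of this change: resizing a tensor's storage to zero returns its CUDA memory to the caching allocator while the tensor object keeps its shape and dtype, and resizing the storage back to numel() re-allocates it with undefined contents. A standalone sketch of that behaviour (illustration only, not part of this commit; it assumes a CUDA device):

import torch

def report(tag: str) -> None:
    # bytes currently occupied by live tensors, as tracked by the caching allocator
    print(f"{tag}: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB")

x = torch.randn(1024, 1024, device="cuda")   # ~4 MiB of fp32 payload
report("after creation")

x.storage().resize_(0)                       # what free_storage(x) boils down to
report("after resize_(0)")                   # payload released; shape/dtype metadata kept
print(x.shape, x.dtype)                      # torch.Size([1024, 1024]) torch.float32

x.storage().resize_(x.numel())               # what alloc_storage(x) boils down to
report("after resize_(numel)")               # payload re-allocated; contents are undefined
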

tests/test_gemini/test_param_tracer.py (2 changes)

@@ -16,7 +16,7 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
     model.backward(loss)


 def run_param_wrapper_testing():
-    test_models = ['simple_net']
+    test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']

     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
