mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] free and allocate cuda memory by tensor.storage, add grad hook (#2040)
parent
1e885329f4
commit
6a9158f1fa
|
@ -1,9 +1,11 @@
|
|||
import torch.nn
|
||||
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
|
||||
from colossalai.gemini.tensor_utils import free_storage
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||
from functools import partial
|
||||
|
||||
|
||||
__all__ = ['ParamTracerWrapper']
|
||||
|
||||
|
@ -13,17 +15,21 @@ class ParamTracerWrapper():
|
|||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.param_op_hook = ParamTracerHook()
|
||||
self.param_op_hook = ParamTracerHook(dtype)
|
||||
|
||||
for p in module.parameters():
|
||||
assert isinstance(p, ColoParameter)
|
||||
p.data = p.data.to(dtype)
|
||||
if p.requires_grad:
|
||||
p.register_hook(partial(self.grad_handle))
|
||||
|
||||
self._cast_buffers_to_cuda_dtype()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def grad_handle(self, grad):
|
||||
free_storage(grad)
|
||||
|
||||
def _pre_forward(self):
|
||||
self.param_op_hook.mem_monitor.start()
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch
|
|||
|
||||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||
from colossalai.gemini.tensor_utils import free_storage, alloc_storage
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -16,25 +17,26 @@ class TrainingPhase(Enum):
|
|||
|
||||
class ParamTracerHook(ParamOpHook):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, dtype: torch.dtype = torch.half) -> None:
|
||||
super().__init__()
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||
self._non_model_data_list = []
|
||||
self._model_data_list = []
|
||||
self.dtype = dtype
|
||||
|
||||
def _move_params_to_dev(self, params, dev: str) -> int:
|
||||
assert isinstance(dev, str), f"device should be a str not torch.device"
|
||||
comm_volume = 0
|
||||
def _free_cuda_params(self, params):
|
||||
for p in params:
|
||||
if p.data.device.type != dev:
|
||||
p.data = p.data.to(dev)
|
||||
comm_volume += p.data.numel() * p.data.element_size()
|
||||
if p.grad is not None:
|
||||
if p.grad.device.type != dev:
|
||||
p.grad = p.grad.to(dev)
|
||||
comm_volume += p.grad.numel() * p.grad.element_size()
|
||||
return comm_volume
|
||||
free_storage(p.data)
|
||||
|
||||
def _allocate_params_on_cuda(self, params):
|
||||
for p in params:
|
||||
cur_dev = p.data.device.type
|
||||
if cur_dev == "cpu":
|
||||
# p.data = p.data.to("cuda")
|
||||
p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
|
||||
elif cur_dev == "cuda":
|
||||
alloc_storage(p.data)
|
||||
|
||||
def sample_model_data(self, params):
|
||||
data_volume = 0
|
||||
|
@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
|
|||
cuda_volume = self.mem_monitor.finish()
|
||||
if len(self._model_data_list):
|
||||
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
|
||||
self._move_params_to_dev(params, 'cuda')
|
||||
self._allocate_params_on_cuda(params)
|
||||
self.sample_model_data(params)
|
||||
self.mem_monitor.start()
|
||||
|
||||
def post_op(self, params):
|
||||
self._move_params_to_dev(params, 'cpu')
|
||||
self._free_cuda_params(params)
|
||||
|
||||
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
||||
self.pre_op(params)
|
||||
|
|
|
@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
|
|||
from typing import Union, Tuple
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
t = tensor.payload
|
||||
|
|
|
@ -16,7 +16,7 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
|
|||
model.backward(loss)
|
||||
|
||||
def run_param_wrapper_testing():
|
||||
test_models = ['simple_net']
|
||||
test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
|
||||
|
||||
for model_name in test_models:
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
|
|
Loading…
Reference in New Issue