mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] free and allocate cuda memory by tensor.storage, add grad hook (#2040)
parent
1e885329f4
commit
6a9158f1fa
|
@ -1,9 +1,11 @@
|
||||||
import torch.nn
|
import torch.nn
|
||||||
|
|
||||||
from colossalai.tensor.colo_parameter import ColoParameter
|
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
|
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
|
||||||
|
from colossalai.gemini.tensor_utils import free_storage
|
||||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['ParamTracerWrapper']
|
__all__ = ['ParamTracerWrapper']
|
||||||
|
|
||||||
|
@ -13,17 +15,21 @@ class ParamTracerWrapper():
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module = module
|
self.module = module
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.param_op_hook = ParamTracerHook()
|
self.param_op_hook = ParamTracerHook(dtype)
|
||||||
|
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
assert isinstance(p, ColoParameter)
|
|
||||||
p.data = p.data.to(dtype)
|
p.data = p.data.to(dtype)
|
||||||
|
if p.requires_grad:
|
||||||
|
p.register_hook(partial(self.grad_handle))
|
||||||
|
|
||||||
self._cast_buffers_to_cuda_dtype()
|
self._cast_buffers_to_cuda_dtype()
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.forward(*args, **kwargs)
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
def grad_handle(self, grad):
|
||||||
|
free_storage(grad)
|
||||||
|
|
||||||
def _pre_forward(self):
|
def _pre_forward(self):
|
||||||
self.param_op_hook.mem_monitor.start()
|
self.param_op_hook.mem_monitor.start()
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import torch
|
||||||
|
|
||||||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||||
|
from colossalai.gemini.tensor_utils import free_storage, alloc_storage
|
||||||
|
|
||||||
|
|
||||||
class TrainingPhase(Enum):
|
class TrainingPhase(Enum):
|
||||||
|
@ -16,25 +17,26 @@ class TrainingPhase(Enum):
|
||||||
|
|
||||||
class ParamTracerHook(ParamOpHook):
|
class ParamTracerHook(ParamOpHook):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, dtype: torch.dtype = torch.half) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._training_phase = TrainingPhase.FORWARD
|
self._training_phase = TrainingPhase.FORWARD
|
||||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||||
self._non_model_data_list = []
|
self._non_model_data_list = []
|
||||||
self._model_data_list = []
|
self._model_data_list = []
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
def _move_params_to_dev(self, params, dev: str) -> int:
|
def _free_cuda_params(self, params):
|
||||||
assert isinstance(dev, str), f"device should be a str not torch.device"
|
|
||||||
comm_volume = 0
|
|
||||||
for p in params:
|
for p in params:
|
||||||
if p.data.device.type != dev:
|
free_storage(p.data)
|
||||||
p.data = p.data.to(dev)
|
|
||||||
comm_volume += p.data.numel() * p.data.element_size()
|
def _allocate_params_on_cuda(self, params):
|
||||||
if p.grad is not None:
|
for p in params:
|
||||||
if p.grad.device.type != dev:
|
cur_dev = p.data.device.type
|
||||||
p.grad = p.grad.to(dev)
|
if cur_dev == "cpu":
|
||||||
comm_volume += p.grad.numel() * p.grad.element_size()
|
# p.data = p.data.to("cuda")
|
||||||
return comm_volume
|
p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
|
||||||
|
elif cur_dev == "cuda":
|
||||||
|
alloc_storage(p.data)
|
||||||
|
|
||||||
def sample_model_data(self, params):
|
def sample_model_data(self, params):
|
||||||
data_volume = 0
|
data_volume = 0
|
||||||
|
@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
|
||||||
cuda_volume = self.mem_monitor.finish()
|
cuda_volume = self.mem_monitor.finish()
|
||||||
if len(self._model_data_list):
|
if len(self._model_data_list):
|
||||||
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
|
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
|
||||||
self._move_params_to_dev(params, 'cuda')
|
self._allocate_params_on_cuda(params)
|
||||||
self.sample_model_data(params)
|
self.sample_model_data(params)
|
||||||
self.mem_monitor.start()
|
self.mem_monitor.start()
|
||||||
|
|
||||||
def post_op(self, params):
|
def post_op(self, params):
|
||||||
self._move_params_to_dev(params, 'cpu')
|
self._free_cuda_params(params)
|
||||||
|
|
||||||
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
||||||
self.pre_op(params)
|
self.pre_op(params)
|
||||||
|
|
|
@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||||
|
return tensor.storage().size() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def free_storage(tensor: torch.Tensor) -> None:
|
||||||
|
if not is_storage_empty(tensor):
|
||||||
|
tensor.storage().resize_(0)
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||||
|
if is_storage_empty(tensor):
|
||||||
|
tensor.storage().resize_(tensor.numel())
|
||||||
|
|
||||||
|
|
||||||
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
||||||
if isinstance(tensor, StatefulTensor):
|
if isinstance(tensor, StatefulTensor):
|
||||||
t = tensor.payload
|
t = tensor.payload
|
||||||
|
|
|
@ -16,7 +16,7 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
|
||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
|
|
||||||
def run_param_wrapper_testing():
|
def run_param_wrapper_testing():
|
||||||
test_models = ['simple_net']
|
test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
|
||||||
|
|
||||||
for model_name in test_models:
|
for model_name in test_models:
|
||||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
|
|
Loading…
Reference in New Issue