@@ -7,6 +7,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.gemini.tensor_utils import free_storage, alloc_storage


 class TrainingPhase(Enum):
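Note: free_storage and alloc_storage are the storage-level helpers this change switches to. A minimal sketch of their behavior (an assumption inferred from how the hunks below use them, not the verbatim colossalai.gemini.tensor_utils code): they resize a tensor's underlying storage so the Python tensor object survives while its device memory is released or restored.

    import torch

    def free_storage(data: torch.Tensor) -> None:
        # Shrink the storage to zero elements: the device memory is freed,
        # but the tensor keeps its shape/dtype metadata.
        if data.storage().size() > 0:
            data.storage().resize_(0)

    def alloc_storage(data: torch.Tensor) -> None:
        # Re-grow the storage so it can hold the tensor's elements again.
        if data.storage().size() == 0:
            data.storage().resize_(data.numel())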
@@ -16,25 +17,26 @@ class TrainingPhase(Enum):


 class ParamTracerHook(ParamOpHook):

-    def __init__(self) -> None:
+    def __init__(self, dtype: torch.dtype = torch.half) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self.mem_monitor = SyncCudaMemoryMonitor()
         self._non_model_data_list = []
         self._model_data_list = []
+        self.dtype = dtype

-    def _move_params_to_dev(self, params, dev: str) -> int:
-        assert isinstance(dev, str), f"device should be a str not torch.device"
-        comm_volume = 0
+    def _free_cuda_params(self, params):
         for p in params:
-            if p.data.device.type != dev:
-                p.data = p.data.to(dev)
-                comm_volume += p.data.numel() * p.data.element_size()
-            if p.grad is not None:
-                if p.grad.device.type != dev:
-                    p.grad = p.grad.to(dev)
-                    comm_volume += p.grad.numel() * p.grad.element_size()
-        return comm_volume
+            free_storage(p.data)

+    def _allocate_params_on_cuda(self, params):
+        for p in params:
+            cur_dev = p.data.device.type
+            if cur_dev == "cpu":
+                # p.data = p.data.to("cuda")
+                p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+            elif cur_dev == "cuda":
+                alloc_storage(p.data)

     def sample_model_data(self, params):
         data_volume = 0
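Two things worth calling out in this hunk. First, _allocate_params_on_cuda deliberately materializes placeholder values with torch.randn in self.dtype (half precision by default) instead of copying the real CPU weights (the commented-out .to("cuda") line): the tracer only needs the memory footprint of the parameters, not their contents, so a cheap device-side allocation is enough. Second, the hunk cuts off inside sample_model_data; a hedged sketch of how it plausibly continues (only `data_volume = 0` is in the diff, the loop body below is an assumption):

    def sample_model_data(self, params):
        data_volume = 0
        for p in params:
            # Count the bytes this op's parameters occupy on the device.
            data_volume += p.data.numel() * p.data.element_size()
        self._model_data_list.append(data_volume)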
@@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):

         cuda_volume = self.mem_monitor.finish()
         if len(self._model_data_list):
             self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
-        self._move_params_to_dev(params, 'cuda')
+        self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()

     def post_op(self, params):
-        self._move_params_to_dev(params, 'cpu')
+        self._free_cuda_params(params)

     def pre_forward(self, params: List[torch.Tensor]) -> None:
         self.pre_op(params)
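The accounting in this last hunk is the core of the tracer: post_op releases parameter storage after each op, and the next pre_op reads the CUDA peak accumulated since the previous mem_monitor.start(), then attributes everything above the last model-data sample to non-model data (activations, temporary buffers, fragmentation). A small worked example of that subtraction (illustrative numbers, not measured values):

    # Peak reported by SyncCudaMemoryMonitor since the previous op finished.
    cuda_volume = 1 * 1024**3                # 1 GiB peak
    # Bytes of parameters the hook itself placed on the GPU for that op.
    model_data = 768 * 1024**2               # 0.75 GiB
    non_model_data = cuda_volume - model_data
    assert non_model_data == 256 * 1024**2   # 0.25 GiB attributed to non-model data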