@@ -7,6 +7,7 @@ import torch
 
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.gemini.tensor_utils import free_storage, alloc_storage
 
 
 class TrainingPhase(Enum):
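Note for reviewers: the newly imported `free_storage` / `alloc_storage` helpers release and re-create a tensor's backing storage while the tensor object keeps its shape, dtype, and device metadata. A minimal sketch of the underlying idea, assuming the helpers in `colossalai.gemini.tensor_utils` are built on the usual `storage().resize_()` trick:

```python
import torch

t = torch.randn(4, 4, device="cuda")
t.storage().resize_(0)           # free_storage: payload released, metadata kept
assert t.shape == (4, 4)         # tensor still describes a 4x4 buffer
t.storage().resize_(t.numel())   # alloc_storage: payload re-allocated, uninitialized
```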
@@ -16,25 +17,26 @@ class TrainingPhase(Enum):
 
 class ParamTracerHook(ParamOpHook):
 
-    def __init__(self) -> None:
+    def __init__(self, dtype: torch.dtype = torch.half) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self.mem_monitor = SyncCudaMemoryMonitor()
         self._non_model_data_list = []
         self._model_data_list = []
+        self.dtype = dtype
 
-    def _move_params_to_dev(self, params, dev: str) -> int:
-        assert isinstance(dev, str), f"device should be a str not torch.device"
-        comm_volume = 0
+    def _free_cuda_params(self, params):
         for p in params:
-            if p.data.device.type != dev:
-                p.data = p.data.to(dev)
-                comm_volume += p.data.numel() * p.data.element_size()
-            if p.grad is not None:
-                if p.grad.device.type != dev:
-                    p.grad = p.grad.to(dev)
-                    comm_volume += p.grad.numel() * p.grad.element_size()
-        return comm_volume
+            free_storage(p.data)
+
+    def _allocate_params_on_cuda(self, params):
+        for p in params:
+            cur_dev = p.data.device.type
+            if cur_dev == "cpu":
+                # p.data = p.data.to("cuda")
+                p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+            elif cur_dev == "cuda":
+                alloc_storage(p.data)
 
     def sample_model_data(self, params):
         data_volume = 0
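The notable choice in `_allocate_params_on_cuda` is fabricating a fresh buffer with `torch.randn` instead of the commented-out `p.data.to("cuda")`: for memory tracing only the CUDA footprint matters, so a dummy tensor of the right shape in the hook's dtype (half precision by default) is enough, and the host-to-device copy of the fp32 master weight is skipped entirely. A sketch of the difference, with illustrative sizes:

```python
import torch

p_cpu = torch.empty(1024, 1024)       # fp32 parameter resident on CPU: 4 MiB

# Real transfer (what the commented-out line would do): pays a PCIe copy
# and occupies the full fp32 size on the device.
# p_cuda = p_cpu.to("cuda")           # 4 MiB on CUDA

# The PR's tracing trick: same shape, fp16, no transfer. The memory monitor
# sees the footprint the fp16 training run would actually produce.
p_cuda = torch.randn(p_cpu.shape, device="cuda", dtype=torch.half)
print(p_cuda.numel() * p_cuda.element_size())   # 2097152 bytes = 2 MiB
```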
@@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
         cuda_volume = self.mem_monitor.finish()
         if len(self._model_data_list):
             self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
-        self._move_params_to_dev(params, 'cuda')
+        self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
 
     def post_op(self, params):
-        self._move_params_to_dev(params, 'cpu')
+        self._free_cuda_params(params)
 
     def pre_forward(self, params: List[torch.Tensor]) -> None:
         self.pre_op(params)
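For context, a rough sketch of how a `ParamOpHook` subclass like this one is driven. The `ParamOpHookManager.use_hooks` context manager and the `model` / `inputs` names below are assumptions for illustration; the exact entry point may differ across ColossalAI versions:

```python
import torch
from colossalai.tensor.param_op_hook import ParamOpHookManager

hook = ParamTracerHook(dtype=torch.half)   # defaults to torch.half per this PR
with ParamOpHookManager.use_hooks(hook):
    out = model(inputs)                    # pre_op allocates params on CUDA
    out.sum().backward()                   # post_op frees their storage again
# hook._model_data_list / hook._non_model_data_list hold the traced volumes
```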