Browse Source

[hotfix] revert bug PRs (#2016)

pull/2017/head^2
Jiarui Fang 2 years ago committed by GitHub
parent
commit
0b0d8f9e17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      colossalai/gemini/memory_tracer/__init__.py
  2. 51
      colossalai/gemini/memory_tracer/param_tracer_wrapper.py
  3. 81
      colossalai/gemini/ophooks/param_trace_hook.py
  4. 47
      tests/test_gemini/test_mem_tracer_paramOP.py

3
colossalai/gemini/memory_tracer/__init__.py

@ -4,9 +4,8 @@ from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
from .module_tracer_wrapper import MemtracerWrapper # isort:skip
from .param_tracer_wrapper import ParamWrapper # isort:skip
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper', 'ParamWrapper'
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper'
]

51
colossalai/gemini/memory_tracer/param_tracer_wrapper.py

@ -1,51 +0,0 @@
import torch.nn
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.ophooks import ParamMemHook
from colossalai.nn.parallel.data_parallel import _cast_float
class ParamWrapper():
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
super().__init__()
self.module = module
self.dtype = dtype
self.param_op_hook = ParamMemHook()
for p in module.parameters():
assert isinstance(p, ColoParameter)
p.data = p.data.to(dtype)
self._cast_buffers_to_cuda_dtype()
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def _pre_forward(self):
self.param_op_hook.mem_monitor.start()
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
self.module.zero_grad(set_to_none=True)
self._pre_forward()
with ParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
return outputs
def backward(self, loss):
with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def _post_backward(self):
cuda_volume = self.param_op_hook.mem_monitor.finish()
last_model_data = self.param_op_hook._model_data_list[-1]
self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data)
def _cast_buffers_to_cuda_dtype(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.data.to(self.dtype)

81
colossalai/gemini/ophooks/param_trace_hook.py

@ -1,81 +0,0 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.tensor.param_op_hook import ParamOpHook
class TrainingPhase(Enum):
FORWARD = 0
BACKWARD = 1
class ParamMemHook(ParamOpHook):
def __init__(self) -> None:
super().__init__()
self._training_phase = TrainingPhase.FORWARD
self.mem_monitor = SyncCudaMemoryMonitor()
self._non_model_data_list = []
self._model_data_list = []
def _move_params_to_dev(self, params, dev: str) -> int:
assert isinstance(dev, str), f"device should be a str not torch.device"
comm_volume = 0
for p in params:
if p.data.device.type != dev:
p.data = p.data.to(dev)
comm_volume += p.data.numel() * p.data.element_size()
if p.grad is not None:
if p.grad.device.type != dev:
p.grad = p.grad.to(dev)
comm_volume += p.grad.numel() * p.grad.element_size()
return comm_volume
def sample_model_data(self, params):
data_volume = 0
for p in params:
data_volume += p.data.numel() * p.data.element_size()
if self._training_phase == TrainingPhase.BACKWARD:
# add param.grad, actually param.grad is None in this time
data_volume *= 2
self._model_data_list.append(data_volume)
def pre_op(self, params):
cuda_volume = self.mem_monitor.finish()
if len(self._model_data_list):
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
self._move_params_to_dev(params, 'cuda')
self.sample_model_data(params)
self.mem_monitor.start()
def post_op(self, params):
self._move_params_to_dev(params, 'cpu')
def pre_forward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_forward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
def pre_backward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)
def post_backward(self, params: List[torch.Tensor]) -> None:
self.post_op(params)
@contextmanager
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
old_training_phase = self._training_phase
try:
self._training_phase = training_phase
yield
finally:
self._training_phase = old_training_phase
switch_to_backward = switch_training_phase
switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)

47
tests/test_gemini/test_mem_tracer_paramOP.py

@ -1,47 +0,0 @@
import numpy as np
import torch
from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamWrapper
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
model.backward(loss)
def run_param_wrapper_testing():
test_models = ['repeated_computed_layers', 'simple_net', 'no_leaf_module', 'bert']
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
with ColoInitContext(device=torch.device('cpu')):
model = model_builder(checkpoint=False)
model = ParamWrapper(model)
for i, (data, label) in enumerate(train_dataloader):
if i > 1:
break
data = data.cuda()
label = label.cuda()
run_fwd_bwd(model, data, label, criterion, False)
cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
# print(model.param_op_hook._non_model_data_list)
del model
if __name__ == '__main__':
run_param_wrapper_testing()
Loading…
Cancel
Save