[zero] non model data tracing (#545)

Jiarui Fang 2022-03-29 15:45:48 +08:00 committed by GitHub
parent 73d36618a6
commit 53b1b6e340
7 changed files with 64 additions and 27 deletions

View File

@@ -1,12 +1,25 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
import torch
from typing import Union, Tuple, Optional
from typing import Tuple, Optional
from colossalai.logging import DistributedLogger
def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
"""Trace the optimizer memory usage
Args:
optim (ShardedOptimizerV2): an instance of ShardedOptimizerV2
Returns:
Tuple[int, int]: cuda/cpu memory usage in Byte
"""
if optim is None:
return 0, 0
assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
return optim.get_memory_usage()
def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
"""
Trace the model memory usage.
Args:
@@ -15,6 +28,8 @@ def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
Returns:
Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
"""
if model is None:
return 0, 0
def _get_tensor_mem_use(t: Optional[torch.Tensor]):
if t is None:
@@ -31,9 +46,9 @@ def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
cpu_mem_usage = 0
for param in model.parameters():
if hasattr(param, 'col_attr'):
para_cuda, param_cpu = param.col_attr.get_memory_usage()
cuda_mem_usage += para_cuda
cpu_mem_usage += param_cpu
t_cuda, t_cpu = param.col_attr.get_memory_usage()
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
else:
t_cuda, t_cpu = _get_tensor_mem_use(param.data)
cuda_mem_usage += t_cuda
@@ -54,6 +69,7 @@ class ModelDataTracer(metaclass=SingletonMeta):
def __init__(self) -> None:
self._logger = DistributedLogger("ModelDataTracer")
self._model = None
self._opitimizer = None
def _get_mem_usage(self) -> Tuple[int, int]:
"""
@@ -61,14 +77,20 @@ class ModelDataTracer(metaclass=SingletonMeta):
Returns:
Tuple[int, int]: cuda, cpu mem usage
"""
if self._model is None:
self._logger.warning("The Global ModelDataTracer is using, but no model is registered on it.")
return 0, 0
return col_model_data_mem_usage(self._model)
cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer)
cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model)
return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model
def register_model(self, model) -> None:
if self._model is not None:
self._logger.warning("ModelDataTracer has already registered a model")
self._model = model
def register_optimizer(self, optimizer) -> None:
if self._opitimizer is not None:
self._logger.warning("ModelDataTracer has already registered an optimizer")
self._opitimizer = optimizer
@property
def cpu_usage(self):
_, cpu_usage = self._get_mem_usage()
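
Taken together, this first file exposes two entry points: the free functions colo_model_mem_usage / colo_model_optimizer_usage and the ModelDataTracer singleton that sums both. A minimal usage sketch, assuming only what the hunks above show (the GLOBAL_MODEL_DATA_TRACER instance, register_model and the cpu_usage property); the toy nn.Linear model is purely illustrative:

    import torch
    from colossalai.utils.memory_tracer.model_data_memtracer import (
        GLOBAL_MODEL_DATA_TRACER, colo_model_mem_usage)

    model = torch.nn.Linear(1024, 1024).cuda()

    # Free-function path: (cuda_bytes, cpu_bytes) of the parameter payloads.
    cuda_bytes, cpu_bytes = colo_model_mem_usage(model)

    # Singleton path: register once, then read the aggregated counters.
    GLOBAL_MODEL_DATA_TRACER.register_model(model)
    print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)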

View File

@@ -1,4 +1,3 @@
from psutil import cpu_count
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

View File

@@ -3,7 +3,7 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
@@ -56,5 +56,5 @@ class TensorShardStrategy(BaseShardStrategy):
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
t.reset_payload(gathered_payload)
colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
colo_model_data_tensor_move_inline(t, target_device)
t.is_sharded = False

View File

@@ -119,6 +119,10 @@ class ShardedModelV2(nn.Module):
self._cuda_margin_space = 0
self.reuse_fp16_shard = reuse_fp16_shard
@property
def use_memory_tracer(self):
return self._use_memory_tracer
@property
def cuda_margin_space(self):
return self._cuda_margin_space
@@ -150,8 +154,8 @@
def _update_memstats(self):
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
self.logger.info(f'model data cuda, {self._memstats_collector.model_data_cuda}')
self.logger.info(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
self.logger.debug(f'model data cuda, {self._memstats_collector.model_data_cuda}')
self.logger.debug(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
if self._memstats_collector:
self._memstats_collector.reset_sampling_cnter()
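
The new use_memory_tracer property lets the sharded optimizer (next file) check whether tracing is enabled on the model it wraps. A small hedged sketch of inspecting the collected statistics after the first iteration, assuming only the attributes the debug lines above reference:

    # sharded_model: a ShardedModelV2 built with use_memory_tracer=True
    if sharded_model.use_memory_tracer:
        collector = sharded_model._memstats_collector
        # The same quantities the debug log prints: CUDA memory attributed to
        # model data versus everything else (activations, buffers, workspace).
        print(collector.model_data_cuda)
        print(collector.non_model_data_cuda)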

View File

@@ -5,6 +5,11 @@ from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -14,10 +19,10 @@ from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_t
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
class OptimState(Enum):
@@ -75,6 +80,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval: float = 1000,
hysteresis: float = 2,
max_scale: int = 2**32,
use_memory_tracer=False,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -129,6 +135,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
ranks=[0])
self._use_memory_tracer = self.model.use_memory_tracer
if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
def get_memory_usage(self) -> Tuple[int, int]:
"""
Get the memory usage of the optimizer. Including master_params (param fp32),
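
Because colo_model_optimizer_usage only requires a get_memory_usage() method (see the assert in the first file), the tracer's optimizer hook is duck-typed. A sketch of that contract using an illustrative stand-in object; ShardedOptimizerV2 performs the same registration on itself when use_memory_tracer is enabled on the wrapped model:

    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

    class _DummyOptim:
        """Illustrative stand-in: any object exposing get_memory_usage() -> (cuda_bytes, cpu_bytes)."""

        def get_memory_usage(self):
            return 0, 0

    GLOBAL_MODEL_DATA_TRACER.register_optimizer(_DummyOptim())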

View File

@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
col_model_data_mem_usage
colo_model_mem_usage
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -51,7 +51,7 @@ def run_model_test(init_device_type, shard_strategy_class):
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
cuda_mem_use, cpu_mem_use = col_model_data_mem_usage(model)
cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
model_data_cuda_mem_MB = cuda_mem_use / 1e6
logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
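
The comparison this test makes is the point of "non model data tracing": whatever CUDA memory the process holds beyond the traced model data counts as non-model data. As an illustrative remainder in the test's own units (not a line from the test itself):

    # hypothetical: CUDA memory not attributable to model parameters
    non_model_data_MB = sys_cuda_mem_MB - model_data_cuda_mem_MB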

View File

@@ -63,11 +63,13 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
shard_param=True,
rm_torch_payload_on_the_fly=False):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=use_cpuadam)
zero_model = ShardedModelV2(
zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=use_cpuadam,
)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
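
Putting the pieces together, a hedged sketch of the flow this test exercises: wrap a model built under ZeroInitContext (as above) with tracing enabled, build the sharded optimizer, which then registers itself with the global tracer, and read the aggregated usage. The names module and shard_strategy stand for the objects created in the test, the argument order for ShardedOptimizerV2 (sharded model first) follows the assert in the hunk above, and its import path is assumed from the package layout:

    import torch
    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
    from colossalai.zero.sharded_model import ShardedModelV2
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    zero_model = ShardedModelV2(module, shard_strategy, use_memory_tracer=True)
    optim = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
    sharded_optim = ShardedOptimizerV2(zero_model, optim)

    # Model parameters and optimizer state are now both visible to the tracer.
    print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)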