mirror of https://github.com/hpcaitech/ColossalAI
[zero] non model data tracing (#545)
parent 73d36618a6
commit 53b1b6e340
@@ -1,12 +1,25 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
import torch
from typing import Union, Tuple, Optional
from typing import Tuple, Optional
from colossalai.logging import DistributedLogger


def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
    """Trace the optimizer memory usage

    Args:
        optim (ShardedOptimV2): an instance of ShardedOptimver

    Returns:
        Tuple[int, int]: cuda/cpu memory usage in Byte
    """
    if optim is None:
        return 0, 0
    assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
    return optim.get_memory_usage()


def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Trace the model memory usage.
    Args:
@@ -15,6 +28,8 @@ def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    Returns:
        Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
    """
    if model is None:
        return 0, 0

    def _get_tensor_mem_use(t: Optional[torch.Tensor]):
        if t is None:
@@ -31,9 +46,9 @@ def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    cpu_mem_usage = 0
    for param in model.parameters():
        if hasattr(param, 'col_attr'):
            para_cuda, param_cpu = param.col_attr.get_memory_usage()
            cuda_mem_usage += para_cuda
            cpu_mem_usage += param_cpu
            t_cuda, t_cpu = param.col_attr.get_memory_usage()
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu
        else:
            t_cuda, t_cpu = _get_tensor_mem_use(param.data)
            cuda_mem_usage += t_cuda
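For orientation only (not part of this commit), here is a minimal sketch of how the two helpers above might be called. The module path is an assumption taken from the test change at the bottom of this diff.

# Illustrative sketch, assuming the helpers live in
# colossalai.utils.memory_tracer.model_data_memtracer as the test below suggests.
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import (colo_model_mem_usage, colo_model_optimizer_usage)

model = torch.nn.Linear(1024, 1024)                  # placeholder; any torch.nn.Module works
cuda_bytes, cpu_bytes = colo_model_mem_usage(model)  # falls back to param.data when no col_attr is present
print(f"model data: {cuda_bytes / 1e6} MB CUDA, {cpu_bytes / 1e6} MB CPU")

# None is accepted and reported as zero usage; a real ShardedOptimizerV2 would be
# traced through its get_memory_usage() method instead.
opt_cuda_bytes, opt_cpu_bytes = colo_model_optimizer_usage(None)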
@@ -54,6 +69,7 @@ class ModelDataTracer(metaclass=SingletonMeta):
    def __init__(self) -> None:
        self._logger = DistributedLogger("ModelDataTracer")
        self._model = None
        self._opitimizer = None

    def _get_mem_usage(self) -> Tuple[int, int]:
        """
@@ -61,14 +77,20 @@
        Returns:
            Tuple[int, int]: cuda, cpu mem usage
        """
        if self._model is None:
            self._logger.warning("The Global ModelDataTracer is using, but no model is registered on it.")
            return 0, 0
        return col_model_data_mem_usage(self._model)
        cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer)
        cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model)
        return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model

    def register_model(self, model) -> None:
        if self._model is not None:
            self._logger.warning("ModelDataTracer has already registered a model")
        self._model = model

    def register_optimizer(self, optimizer) -> None:
        if self._opitimizer is not None:
            self._logger.warning("ModelDataTracer has already registered an optimizer")
        self._opitimizer = optimizer

    @property
    def cpu_usage(self):
        _, cpu_usage = self._get_mem_usage()
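Again as a sketch rather than part of the diff, the singleton implemented by the class above would typically be driven like this; GLOBAL_MODEL_DATA_TRACER is the instance imported elsewhere in this commit, and the model here is a placeholder.

# Sketch only. register_optimizer() is normally called by ShardedOptimizerV2 itself
# (see the sharded optimizer hunk further down), not by user code.
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

model = torch.nn.Linear(8, 8)                 # placeholder; usually a ShardedModelV2
GLOBAL_MODEL_DATA_TRACER.register_model(model)
print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)     # model (+ registered optimizer) CPU bytes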
@@ -1,4 +1,3 @@
from psutil import cpu_count
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
@@ -3,7 +3,7 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
@@ -56,5 +56,5 @@ class TensorShardStrategy(BaseShardStrategy):
        dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
        gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(gathered_payload)
        colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
        colo_model_data_tensor_move_inline(t, target_device)
        t.is_sharded = False
@@ -119,6 +119,10 @@ class ShardedModelV2(nn.Module):
        self._cuda_margin_space = 0
        self.reuse_fp16_shard = reuse_fp16_shard

    @property
    def use_memory_tracer(self):
        return self._use_memory_tracer

    @property
    def cuda_margin_space(self):
        return self._cuda_margin_space
@@ -150,8 +154,8 @@
    def _update_memstats(self):
        if self._iter_cnter == 0 and self._memstats_collector:
            self._memstats_collector.finish_collection()
            self.logger.info(f'model data cuda, {self._memstats_collector.model_data_cuda}')
            self.logger.info(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')
            self.logger.debug(f'model data cuda, {self._memstats_collector.model_data_cuda}')
            self.logger.debug(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')

        if self._memstats_collector:
            self._memstats_collector.reset_sampling_cnter()
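A sketch (not from the commit) of how the new properties might be consumed; build_sharded_model is a hypothetical stand-in for the ZeroInitContext-based setup used in the test at the end of this diff, and the constructor keywords mirror that test.

# Sketch only; build_sharded_model() is a hypothetical helper, not a real API.
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

shard_strategy = TensorShardStrategy()
module = build_sharded_model(shard_strategy)      # hypothetical: a module built under ZeroInitContext
zero_model = ShardedModelV2(module, shard_strategy, use_memory_tracer=True)

assert zero_model.use_memory_tracer               # read-only property added in this hunk
print(zero_model.cuda_margin_space)               # 0 until memory statistics have been collected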
@@ -5,6 +5,11 @@ from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -14,10 +19,10 @@ from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_t
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


class OptimState(Enum):
@@ -75,6 +80,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 growth_interval: float = 1000,
                 hysteresis: float = 2,
                 max_scale: int = 2**32,
                 use_memory_tracer=False,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None) -> None:
        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -129,6 +135,10 @@
        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

        self._use_memory_tracer = self.model.use_memory_tracer
        if self._use_memory_tracer:
            GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)

    def get_memory_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the optimizer. Including master_params (param fp32),
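Not part of the diff: a sketch of the optimizer-side accounting introduced above, assuming sharded_optim is an already constructed ShardedOptimizerV2.

# Sketch only; sharded_optim stands for an existing ShardedOptimizerV2 instance.
cuda_bytes, cpu_bytes = sharded_optim.get_memory_usage()
print(f"optimizer state: {cuda_bytes / 1e6} MB CUDA, {cpu_bytes / 1e6} MB CPU")

# When the wrapped ShardedModelV2 was created with use_memory_tracer=True, __init__ above
# also calls GLOBAL_MODEL_DATA_TRACER.register_optimizer(self), so these numbers are folded
# into the global tracer's totals without any extra user code.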
@@ -12,7 +12,7 @@ from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
    col_model_data_mem_usage
    colo_model_mem_usage
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -51,7 +51,7 @@ def run_model_test(init_device_type, shard_strategy_class):
        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

    cuda_mem_use, cpu_mem_use = col_model_data_mem_usage(model)
    cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
    model_data_cuda_mem_MB = cuda_mem_use / 1e6
    logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
    sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
@@ -63,11 +63,13 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False):
        zero_model = model_builder(checkpoint=True)
    zero_model = ShardedModelV2(zero_model,
    zero_model = ShardedModelV2(
        zero_model,
        shard_strategy,
        offload_config=dict(device='cpu') if cpu_offload else None,
        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
        reuse_fp16_shard=use_cpuadam)
        reuse_fp16_shard=use_cpuadam,
    )

    model = model_builder(checkpoint=True).half()
    col_model_deepcopy(zero_model, model)
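To close the loop on the commit title, a sketch (not from the diff) of how non-model data can be derived from the pieces above, following the measurement pattern used in the test; zero_model is a placeholder for the sharded model built there.

# Sketch only; zero_model stands for the ShardedModelV2 constructed in the test above.
from colossalai.utils.memory_tracer.model_data_memtracer import colo_model_mem_usage
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used

model_data_cuda, _ = colo_model_mem_usage(zero_model)   # bytes attributed to model data
total_cuda = colo_cuda_memory_used()                     # bytes currently used on the device
non_model_data_cuda = total_cuda - model_data_cuda       # activations, buffers, workspace, ...
print(f"non-model data: {non_model_data_cuda / 1e6} MB CUDA")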