
[zero] get memory usage of sharded optim v2. (#542)

Authored by Jiarui Fang 3 years ago, committed by GitHub
commit c11ff81b15
  1. colossalai/utils/memory_tracer/model_data_memtracer.py (9 changed lines)
  2. colossalai/utils/memory_utils/utils.py (22 changed lines)
  3. colossalai/zero/sharded_optim/sharded_optim_v2.py (53 changed lines)
  4. colossalai/zero/sharded_param/sharded_param.py (8 changed lines)
  5. tests/test_zero_data_parallel/test_sharded_optim_v2.py (12 changed lines)

colossalai/utils/memory_tracer/model_data_memtracer.py (9 changed lines)

@@ -1,18 +1,11 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
import torch
from typing import Union, Tuple, Optional
from colossalai.logging import DistributedLogger


def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()


def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Trace the model memory usage.

colossalai/utils/memory_utils/utils.py (22 changed lines)

@@ -1,12 +1,32 @@
from psutil import cpu_count
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from typing import Union
from typing import Tuple, Union

_GLOBAL_CUDA_MEM_FRACTION = 1.0


def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]:
    if isinstance(tensor, ShardedTensor):
        t = tensor.payload
    elif isinstance(tensor, torch.Tensor):
        t = tensor
    else:
        return 0, 0
    cuda_use, cpu_use = 0, 0
    mem_use = t.numel() * t.element_size()
    if t.device.type == 'cuda':
        cuda_use += mem_use
    elif t.device.type == 'cpu':
        cpu_use += mem_use
    return cuda_use, cpu_use


def colo_set_process_memory_fraction(ratio: float) -> None:
    """colo_set_process_memory_fraction

colossalai/zero/sharded_optim/sharded_optim_v2.py (53 changed lines)

@@ -1,5 +1,6 @@
from enum import Enum
from typing import Dict, Optional
from os import stat
from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
@@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
class OptimState(Enum):
@@ -26,14 +27,20 @@ class OptimState(Enum):
class ShardedOptimizerV2(ColossalaiOptimizer):
    """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).
    By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.
    We apply the Device-aware Operator Placement technique for OS placement from the following paper.
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
    which is detected by a runtime memory tracer.
    We place as many OS chunks in the margin space as possible.
    The size of margin space can be controlled by `gpu_margin_mem_ratio`
    The size of margin space can be controlled by `gpu_margin_mem_ratio`
    If it is set as 0.0, it is the same as classical ZeRO optimizer.
    NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`.
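
As a rough, hypothetical illustration of the margin-space sizing described in the docstring above (this function does not exist in the commit; `peak_non_model_data` stands in for the peak non-model usage reported by the runtime memory tracer):

import torch

def estimate_usable_margin(gpu_margin_mem_ratio: float, peak_non_model_data: int) -> int:
    # Overall memory of the current CUDA device, in bytes.
    overall = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    # Margin space = overall GPU memory minus peak non-model data; gpu_margin_mem_ratio
    # scales how much of that margin is used for optimizer-state (OS) shards.
    # With a ratio of 0.0, no OS shards stay on the GPU (classical ZeRO behaviour).
    return int((overall - peak_non_model_data) * gpu_margin_mem_ratio)
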
@@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
        self._logger = get_dist_logger()
        self._logger = get_dist_logger("ShardedOptimizerV2")

        # Store fp32 param shards
        self.master_params: Dict[Parameter, Tensor] = {}
@@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # So we gather here
                self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

    def get_memory_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the optimizer. Including master_params (param fp32),
        momentum (self.state[p]['exp_avg']) variance (self.state[p]['exp_avg_sq'])

        Returns:
            Tuple[int, int]: cuda/cpu memory usage in Byte.
        """
        cuda_use = 0
        cpu_use = 0

        def update_mem_use(t):
            nonlocal cuda_use
            nonlocal cpu_use
            t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
            cuda_use += t_cuda_use
            cpu_use += t_cpu_use

        for _, p_fp32 in self.master_params.items():
            update_mem_use(p_fp32)
        for group in self.optim.param_groups:
            for p in group['params']:
                state = self.optim.state[p]
                for k, v in state.items():
                    update_mem_use(v)

        return cuda_use, cpu_use

    def step(self, *args, **kwargs):
        self._maybe_move_fp32_shards()
@@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self.grad_scaler.update(found_inf)

        if found_inf:
            self._logger.info('found inf during ShardedOptimV2 step')
            self._logger.warning('found inf during ShardedOptimV2 step')
            self.zero_grad()
            return
@@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        # Now p.data is sharded
        # So optimizer states are sharded naturally
        self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

        ret = self.optim.step(*args, **kwargs)

        self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])
        # Copy master param data (fp32) to payload of col_attr (fp16)
        # TODO() improve efficiency by gathering tensors into a chunk and transfering
        # a chunk.
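
The debug lines above already use the new API internally; user code can read the same numbers directly. A hypothetical snippet (`sharded_optim` is assumed to be an already constructed ShardedOptimizerV2, not a name taken from this commit):

cuda_use, cpu_use = sharded_optim.get_memory_usage()
print(f"master params + optimizer states: "
      f"{cuda_use / 1e6:.1f} MB on CUDA, {cpu_use / 1e6:.1f} MB on CPU")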

colossalai/zero/sharded_param/sharded_param.py (8 changed lines)

@@ -2,6 +2,7 @@ import torch
import torch.distributed as dist
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage


class ShardedParamV2(object):

@@ -55,10 +56,9 @@ class ShardedParamV2(object):
            assert isinstance(t, torch.Tensor)
            nonlocal cuda_mem_use
            nonlocal cpu_mem_use
            if t.device.type == 'cpu':
                cpu_mem_use += t.numel() * t.element_size()
            elif t.device.type == 'cuda':
                cuda_mem_use += t.numel() * t.element_size()
            t_cuda, t_cpu = colo_tensor_mem_usage(t)
            cuda_mem_use += t_cuda
            cpu_mem_use += t_cpu

        address_set = set()
        _update_mem_use(self.sharded_data_tensor.payload)
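
Both this closure and `ShardedOptimizerV2.get_memory_usage` now route every tensor through `colo_tensor_mem_usage` and accumulate the two counters. The same pattern works for any collection of tensors; a standalone sketch with hypothetical names (not from the commit):

import torch
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage

def total_tensor_mem_usage(tensors):
    """Sum CUDA and CPU bytes over an iterable of torch.Tensor / ShardedTensor."""
    cuda_use, cpu_use = 0, 0
    for t in tensors:
        t_cuda, t_cpu = colo_tensor_mem_usage(t)
        cuda_use += t_cuda
        cpu_use += t_cpu
    return cuda_use, cpu_use

# Two fp16 CPU tensors of shape (64, 64) -> (0, 2 * 64 * 64 * 2) = (0, 16384)
print(total_tensor_mem_usage([torch.zeros(64, 64, dtype=torch.float16) for _ in range(2)]))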

tests/test_zero_data_parallel/test_sharded_optim_v2.py (12 changed lines)

@@ -1,6 +1,7 @@
from functools import partial
import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.distributed as dist

@@ -57,11 +58,12 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device(f'cpu:0'),
                         shard_strategy=shard_strategy,
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False):
    with ZeroInitContext(
            convert_fp16=True,
            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
            shard_strategy=shard_strategy,
            shard_param=True,
            rm_torch_payload_on_the_fly=False):
        zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model,
                                    shard_strategy,
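
A natural follow-up check, sketched here as a hypothetical addition rather than something this commit contains, would assert inside `_run_test_sharded_optim_v2` that the new accounting reports a non-zero footprint on the expected device (assuming the wrapped optimizer is bound to a variable named `sharded_optim`):

cuda_use, cpu_use = sharded_optim.get_memory_usage()
if cpu_offload:
    # Master fp32 shards and Adam states are expected on the host when offloading.
    assert cpu_use > 0
else:
    # Without offloading, the optimizer footprint should be reported on CUDA.
    assert cuda_use > 0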
