[zero] get memory usage of sharded optim v2. (#542)

Jiarui Fang 2022-03-29 09:08:18 +08:00 committed by GitHub
parent a30e2b4c24
commit c11ff81b15
5 changed files with 81 additions and 23 deletions


@@ -1,18 +1,11 @@
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
 import torch
 from typing import Union, Tuple, Optional
 from colossalai.logging import DistributedLogger
-def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
-    if isinstance(t, ShardedTensor):
-        target = t.payload
-    else:
-        target = t
-    return target.numel() * target.element_size()
 def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     """
     Trace the model memory usage.
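
In short: the file-local `_col_tensor_mem_usage` helper, which returned a single byte count, is dropped in favor of the shared `colo_tensor_mem_usage` from `colossalai.utils.memory_utils.utils` (added in the next file), which splits usage into separate CUDA and CPU byte counts.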


@@ -1,12 +1,32 @@
+from psutil import cpu_count
 import torch
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from typing import Union
+from typing import Tuple, Union
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
+def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]:
+    if isinstance(tensor, ShardedTensor):
+        t = tensor.payload
+    elif isinstance(tensor, torch.Tensor):
+        t = tensor
+    else:
+        return 0, 0
+    cuda_use, cpu_use = 0, 0
+    mem_use = t.numel() * t.element_size()
+    if t.device.type == 'cuda':
+        cuda_use += mem_use
+    elif t.device.type == 'cpu':
+        cpu_use += mem_use
+    return cuda_use, cpu_use
 def colo_set_process_memory_fraction(ratio: float) -> None:
     """colo_set_process_memory_fraction

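For reference, a minimal usage sketch of the new helper (not part of the diff; the tensor shape and dtype here are arbitrary):

    import torch
    from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage

    # A 1024 x 1024 fp32 CPU tensor occupies numel * element_size bytes on the host.
    t = torch.zeros(1024, 1024, dtype=torch.float32)
    cuda_use, cpu_use = colo_tensor_mem_usage(t)
    assert cuda_use == 0
    assert cpu_use == 1024 * 1024 * 4  # 4 MB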

@@ -1,5 +1,6 @@
 from enum import Enum
-from typing import Dict, Optional
+from os import stat
+from typing import Dict, Optional, Tuple
 import torch
 import torch.distributed as dist
@@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
 class OptimState(Enum):
@@ -26,14 +27,20 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):
     """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).
     By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.
     We apply the Device-aware Operator Placement technique for OS placement from the following paper.
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
     https://arxiv.org/abs/2108.05818
     GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
     which is detected by a runtime memory tracer.
     We place as many OS chunks in the margin space as possible.
     The size of margin space can be controlled by `gpu_margin_mem_ratio`
     If it is set as 0.0, it is the same as classical ZeRO optimizer.
     NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`.
@@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
         self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
-        self._logger = get_dist_logger()
+        self._logger = get_dist_logger("ShardedOptimizerV2")
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
@@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # So we gather here
                 self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+                           ranks=[0])
+
+    def get_memory_usage(self) -> Tuple[int, int]:
+        """
+        Get the memory usage of the optimizer, including master_params (param fp32),
+        momentum (`self.state[p]['exp_avg']`) and variance (`self.state[p]['exp_avg_sq']`).
+
+        Returns:
+            Tuple[int, int]: cuda/cpu memory usage in Byte.
+        """
+        cuda_use = 0
+        cpu_use = 0
+
+        def update_mem_use(t):
+            nonlocal cuda_use
+            nonlocal cpu_use
+            t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
+            cuda_use += t_cuda_use
+            cpu_use += t_cpu_use
+
+        for _, p_fp32 in self.master_params.items():
+            update_mem_use(p_fp32)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                state = self.optim.state[p]
+                for k, v in state.items():
+                    update_mem_use(v)
+        return cuda_use, cpu_use
+
     def step(self, *args, **kwargs):
         self._maybe_move_fp32_shards()
@@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.grad_scaler.update(found_inf)
         if found_inf:
-            self._logger.info('found inf during ShardedOptimV2 step')
+            self._logger.warning('found inf during ShardedOptimV2 step')
             self.zero_grad()
             return
@@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Now p.data is sharded
         # So optimizer states are sharded naturally
+        self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+                           ranks=[0])
         ret = self.optim.step(*args, **kwargs)
+        self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+                           ranks=[0])
         # Copy master param data (fp32) to payload of col_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
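
A usage sketch for the new accounting API (assuming `optim` is an already-constructed `ShardedOptimizerV2`; the printout mirrors the debug logs added above):

    # get_memory_usage() walks the fp32 master params plus the optimizer's
    # exp_avg/exp_avg_sq states and returns (cuda_bytes, cpu_bytes).
    cuda_use, cpu_use = optim.get_memory_usage()
    print(f'optimizer states: {cuda_use / 1e6:.1f} MB CUDA, {cpu_use / 1e6:.1f} MB CPU')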


@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
+from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
 class ShardedParamV2(object):
@@ -55,10 +56,9 @@ class ShardedParamV2(object):
             assert isinstance(t, torch.Tensor)
             nonlocal cuda_mem_use
             nonlocal cpu_mem_use
-            if t.device.type == 'cpu':
-                cpu_mem_use += t.numel() * t.element_size()
-            elif t.device.type == 'cuda':
-                cuda_mem_use += t.numel() * t.element_size()
+            t_cuda, t_cpu = colo_tensor_mem_usage(t)
+            cuda_mem_use += t_cuda
+            cpu_mem_use += t_cpu
         address_set = set()
         _update_mem_use(self.sharded_data_tensor.payload)
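
The effect is that `_update_mem_use` no longer branches on `t.device.type` itself; it delegates to the shared helper and keeps only the `nonlocal` accumulation, so per-device accounting stays consistent with the optimizer-side code.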


@@ -1,6 +1,7 @@
 from functools import partial
 import colossalai
+from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.distributed as dist
@@ -57,11 +58,12 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-    with ZeroInitContext(convert_fp16=True,
-                         target_device=torch.device(f'cpu:0'),
-                         shard_strategy=shard_strategy,
-                         shard_param=True,
-                         rm_torch_payload_on_the_fly=False):
+    with ZeroInitContext(
+            convert_fp16=True,
+            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
+            shard_strategy=shard_strategy,
+            shard_param=True,
+            rm_torch_payload_on_the_fly=False):
         zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
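
With this change the test's init device follows the offload mode: parameters are materialized on `cpu:0` only when `cpu_offload` is set, and otherwise are built directly on the current CUDA device via the newly imported `get_current_device`.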