mirror of https://github.com/hpcaitech/ColossalAI
[zero] get memory usage of sharded optim v2. (#542)
parent a30e2b4c24
commit c11ff81b15
@@ -1,18 +1,11 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
import torch
from typing import Union, Tuple, Optional
from colossalai.logging import DistributedLogger


def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()


def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Trace the model memory usage.
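For context, the byte count used by both helpers above is simply `numel() * element_size()` of the underlying payload. A minimal standalone sketch (the tensor shape and dtype are hypothetical, not part of the diff):

```python
# Minimal sketch of the size computation used above (hypothetical tensor shape/dtype).
import torch

t = torch.zeros(1024, 1024, dtype=torch.float16)
mem_bytes = t.numel() * t.element_size()   # 1024 * 1024 * 2 bytes
print(mem_bytes)                           # 2097152, i.e. ~2 MB
```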
@@ -1,12 +1,32 @@
from psutil import cpu_count
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

from typing import Union
from typing import Tuple, Union

_GLOBAL_CUDA_MEM_FRACTION = 1.0


def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]:
    if isinstance(tensor, ShardedTensor):
        t = tensor.payload
    elif isinstance(tensor, torch.Tensor):
        t = tensor
    else:
        return 0, 0

    cuda_use, cpu_use = 0, 0

    mem_use = t.numel() * t.element_size()
    if t.device.type == 'cuda':
        cuda_use += mem_use
    elif t.device.type == 'cpu':
        cpu_use += mem_use

    return cuda_use, cpu_use


def colo_set_process_memory_fraction(ratio: float) -> None:
    """colo_set_process_memory_fraction
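A usage sketch of the new `colo_tensor_mem_usage` helper defined in the hunk above; the tensors and the printed numbers are illustrative assumptions, not output from the repository's tests:

```python
# Illustrative use of colo_tensor_mem_usage as defined above (hypothetical tensors).
import torch
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage

cpu_t = torch.zeros(1000, dtype=torch.float32)       # 1000 * 4 bytes, lives on CPU
cuda_use, cpu_use = colo_tensor_mem_usage(cpu_t)
print(cuda_use, cpu_use)                             # expected: 0 4000

# Anything that is neither a ShardedTensor nor a torch.Tensor reports (0, 0).
print(colo_tensor_mem_usage(None))                   # expected: (0, 0)
```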
@@ -1,5 +1,6 @@
from enum import Enum
from typing import Dict, Optional
from os import stat
from typing import Dict, Optional, Tuple

import torch
import torch.distributed as dist
@@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage


class OptimState(Enum):
@@ -26,14 +27,20 @@ class OptimState(Enum):

class ShardedOptimizerV2(ColossalaiOptimizer):
    """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).

    By default, the ZeRO optimizer stage 3 offloads the Optimizer States (OS) to CPU.

    We apply the Device-aware Operator Placement technique for OS placement from the following paper.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    GPU margin space is the space remaining after removing peak non-model data from the overall GPU memory,
    which is detected by a runtime memory tracer.

    We place as many OS chunks in the margin space as possible.

    The size of the margin space can be controlled by `gpu_margin_mem_ratio`.
    If it is set to 0.0, it is the same as the classical ZeRO optimizer.

    NOTE: You must use `ShardedOptimizerV2` with `ShardedModelV2`.
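The margin-space bookkeeping described in this docstring can be summarised with plain arithmetic. The sketch below uses hypothetical numbers and no ColossalAI calls; all values are assumptions for illustration:

```python
# Hypothetical numbers only: how gpu_margin_mem_ratio scales the margin space.
overall_gpu_mem = 16 * 1024**3        # assumed 16 GB device
peak_non_model_data = 6 * 1024**3     # assumed peak reported by the runtime memory tracer
gpu_margin_mem_ratio = 0.5            # user-chosen knob

margin_space = overall_gpu_mem - peak_non_model_data
os_budget_on_gpu = margin_space * gpu_margin_mem_ratio   # OS chunks up to this size may stay on GPU
# With gpu_margin_mem_ratio = 0.0 the budget is zero, i.e. classical ZeRO with full OS offload.
```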
@@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
        self._logger = get_dist_logger()
        self._logger = get_dist_logger("ShardedOptimizerV2")

        # Store fp32 param shards
        self.master_params: Dict[Parameter, Tensor] = {}
@@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # So we gather here
                self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)

        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

    def get_memory_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the optimizer, including the fp32 master params (self.master_params),
        the momentum (self.state[p]['exp_avg']) and the variance (self.state[p]['exp_avg_sq']).

        Returns:
            Tuple[int, int]: CUDA/CPU memory usage in bytes.
        """
        cuda_use = 0
        cpu_use = 0

        def update_mem_use(t):
            nonlocal cuda_use
            nonlocal cpu_use
            t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
            cuda_use += t_cuda_use
            cpu_use += t_cpu_use

        for _, p_fp32 in self.master_params.items():
            update_mem_use(p_fp32)
        for group in self.optim.param_groups:
            for p in group['params']:
                state = self.optim.state[p]
                for k, v in state.items():
                    update_mem_use(v)

        return cuda_use, cpu_use

    def step(self, *args, **kwargs):
        self._maybe_move_fp32_shards()
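A sketch of how the new `get_memory_usage` API might be queried; `sharded_optim` is assumed to be an already constructed `ShardedOptimizerV2` instance:

```python
# Assumes `sharded_optim` is a ShardedOptimizerV2 built elsewhere.
cuda_use, cpu_use = sharded_optim.get_memory_usage()
print(f"fp32 master params + optimizer states: "
      f"{cuda_use / 1e6:.1f} MB CUDA, {cpu_use / 1e6:.1f} MB CPU")
```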
@@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self.grad_scaler.update(found_inf)

        if found_inf:
            self._logger.info('found inf during ShardedOptimV2 step')
            self._logger.warning('found inf during ShardedOptimV2 step')
            self.zero_grad()
            return
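The overflow branch above follows the usual dynamic-loss-scaling pattern: back off the scale, drop the gradients, and skip the step. A generic sketch of that pattern (the optimizer and scaler objects are assumptions, not the real ColossalAI classes):

```python
# Generic skip-on-overflow pattern (assumed optimizer/scaler objects, not ColossalAI internals).
def guarded_step(optimizer, grad_scaler, found_inf: bool):
    grad_scaler.update(found_inf)   # shrink the loss scale if an overflow was detected
    if found_inf:
        optimizer.zero_grad()       # discard the inf/nan gradients
        return None                 # skip this update entirely
    return optimizer.step()
```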
@@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        # Now p.data is sharded
        # So optimizer states are sharded naturally

        self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

        ret = self.optim.step(*args, **kwargs)

        self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])
        # Copy master param data (fp32) to payload of col_attr (fp16)
        # TODO() improve efficiency by gathering tensors into a chunk and transferring
        # a chunk.
@@ -2,6 +2,7 @@ import torch
import torch.distributed as dist
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage


class ShardedParamV2(object):
@@ -55,10 +56,9 @@ class ShardedParamV2(object):
            assert isinstance(t, torch.Tensor)
            nonlocal cuda_mem_use
            nonlocal cpu_mem_use
            if t.device.type == 'cpu':
                cpu_mem_use += t.numel() * t.element_size()
            elif t.device.type == 'cuda':
                cuda_mem_use += t.numel() * t.element_size()
            t_cuda, t_cpu = colo_tensor_mem_usage(t)
            cuda_mem_use += t_cuda
            cpu_mem_use += t_cpu

        address_set = set()
        _update_mem_use(self.sharded_data_tensor.payload)
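The refactor above delegates the per-device accounting to `colo_tensor_mem_usage` while keeping the closure-based accumulators. A standalone sketch of the same pattern (the input tensors are hypothetical):

```python
# Standalone sketch of the nonlocal-accumulator pattern used by ShardedParamV2.get_memory_usage.
import torch
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage

def summed_mem_usage(tensors):
    cuda_mem_use, cpu_mem_use = 0, 0

    def _update_mem_use(t):
        nonlocal cuda_mem_use, cpu_mem_use
        t_cuda, t_cpu = colo_tensor_mem_usage(t)
        cuda_mem_use += t_cuda
        cpu_mem_use += t_cpu

    for t in tensors:
        _update_mem_use(t)
    return cuda_mem_use, cpu_mem_use

# Two small CPU tensors: 10 * 4 + 10 * 2 = 60 bytes on CPU.
print(summed_mem_usage([torch.zeros(10), torch.zeros(10, dtype=torch.float16)]))   # (0, 60)
```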
@@ -1,6 +1,7 @@
from functools import partial

import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.distributed as dist
@@ -57,11 +58,12 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device(f'cpu:0'),
                         shard_strategy=shard_strategy,
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False):
    with ZeroInitContext(
            convert_fp16=True,
            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
            shard_strategy=shard_strategy,
            shard_param=True,
            rm_torch_payload_on_the_fly=False):
        zero_model = model_builder(checkpoint=True)
        zero_model = ShardedModelV2(zero_model,
                                    shard_strategy,
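The test change above selects the initialisation device based on `cpu_offload`. The same selection, written as a standalone helper (an assumed refactoring for illustration, not code from the commit):

```python
# Assumed standalone form of the device selection used in the test above.
import torch
from colossalai.utils.cuda import get_current_device

def pick_target_device(cpu_offload: bool) -> torch.device:
    # Parameters start on CPU when offloading, otherwise on the current GPU.
    return torch.device('cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}')
```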