mirror of https://github.com/hpcaitech/ColossalAI
[zero] get memory usage of sharded optim v2. (#542)
parent a30e2b4c24
commit c11ff81b15
@@ -1,18 +1,11 @@
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
 import torch
 from typing import Union, Tuple, Optional
 from colossalai.logging import DistributedLogger


-def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
-    if isinstance(t, ShardedTensor):
-        target = t.payload
-    else:
-        target = t
-    return target.numel() * target.element_size()
-
-
 def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     """
     Trace the model memory usage.
@@ -1,12 +1,32 @@
+from psutil import cpu_count
 import torch
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

-from typing import Union
+from typing import Tuple, Union

 _GLOBAL_CUDA_MEM_FRACTION = 1.0


+def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]:
+    if isinstance(tensor, ShardedTensor):
+        t = tensor.payload
+    elif isinstance(tensor, torch.Tensor):
+        t = tensor
+    else:
+        return 0, 0
+
+    cuda_use, cpu_use = 0, 0
+
+    mem_use = t.numel() * t.element_size()
+    if t.device.type == 'cuda':
+        cuda_use += mem_use
+    elif t.device.type == 'cpu':
+        cpu_use += mem_use
+
+    return cuda_use, cpu_use
+
+
 def colo_set_process_memory_fraction(ratio: float) -> None:
     """colo_set_process_memory_fraction

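A quick usage sketch of the new helper, assuming a ColossalAI build that contains this commit; tensor sizes are arbitrary and chosen only to make the byte counts easy to check:

import torch
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage

# 1024 float32 elements -> 4096 bytes, attributed to the CPU bucket.
cpu_t = torch.zeros(1024, dtype=torch.float32)
print(colo_tensor_mem_usage(cpu_t))    # (0, 4096)

if torch.cuda.is_available():
    # 1024 float16 elements -> 2048 bytes, attributed to the CUDA bucket.
    gpu_t = torch.zeros(1024, dtype=torch.float16, device='cuda')
    print(colo_tensor_mem_usage(gpu_t))    # (2048, 0)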
@@ -1,5 +1,6 @@
 from enum import Enum
-from typing import Dict, Optional
+from os import stat
+from typing import Dict, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage


 class OptimState(Enum):
@@ -26,14 +27,20 @@ class OptimState(Enum):

 class ShardedOptimizerV2(ColossalaiOptimizer):
     """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).

     By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.

     We apply the Device-aware Operator Placement technique for OS placement from the following paper.

     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
     https://arxiv.org/abs/2108.05818

     GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
     which is detected by a runtime memory tracer.

     We place as many OS chunks in the margin space as possible.
-    The size of margin space can be controlled by `gpu_margin_mem_ratio`
+    The size of margin space can be controlled by `gpu_margin_mem_ratio`。
     If it is set as 0.0, it is the same as classical ZeRO optimizer.

     NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`.
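To make the margin-space description above concrete, here is a back-of-the-envelope sketch; all numbers and variable names are made up for illustration and are not taken from the commit:

# Illustrative only: how gpu_margin_mem_ratio carves out space for optimizer-state (OS) chunks.
overall_gpu_mem      = 40e9   # bytes of GPU memory available to the process
peak_non_model_data  = 25e9   # peak non-model data reported by the runtime memory tracer
gpu_margin_mem_ratio = 0.8    # user-set knob; 0.0 falls back to classical ZeRO behaviour

margin_space = overall_gpu_mem - peak_non_model_data    # 15e9 bytes left over
os_budget    = margin_space * gpu_margin_mem_ratio      # 12e9 bytes usable for OS chunks
print(f"{os_budget / 1e9:.0f} GB of optimizer states can stay on the GPU")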
@@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
         self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
-        self._logger = get_dist_logger()
+        self._logger = get_dist_logger("ShardedOptimizerV2")

         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
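Both call patterns touched by this commit, naming the distributed logger and restricting a record to rank 0, appear verbatim later in the diff; a minimal sketch of that idiom, assuming a ColossalAI build with this logging API:

from colossalai.logging import get_dist_logger

logger = get_dist_logger("ShardedOptimizerV2")
# ranks=[0] keeps the record from being repeated by every data-parallel process.
logger.debug("After init ShardedOptimizerV2 consumes ... MB CUDA Memory!", ranks=[0])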
@@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # So we gather here
                 self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)

+        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+                           ranks=[0])
+
+    def get_memory_usage(self) -> Tuple[int, int]:
+        """
+        Get the memory usage of the optimizer. Including master_params (param fp32),
+        momentum (self.state[p]['exp_avg']) variance (self.state[p]['exp_avg_sq'])
+
+        Returns:
+            Tuple[int, int]: cuda/cpu memory usage in Byte.
+        """
+        cuda_use = 0
+        cpu_use = 0
+
+        def update_mem_use(t):
+            nonlocal cuda_use
+            nonlocal cpu_use
+            t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
+            cuda_use += t_cuda_use
+            cpu_use += t_cpu_use
+
+        for _, p_fp32 in self.master_params.items():
+            update_mem_use(p_fp32)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                state = self.optim.state[p]
+                for k, v in state.items():
+                    update_mem_use(v)
+
+        return cuda_use, cpu_use
+
     def step(self, *args, **kwargs):
         self._maybe_move_fp32_shards()

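The get_memory_usage() method added above relies on a small closure that accumulates into two counters via nonlocal; the same pattern in isolation, as a standalone sketch in plain PyTorch with no ColossalAI dependency:

import torch

def tally(tensors):
    cuda_use, cpu_use = 0, 0

    def update(t: torch.Tensor) -> None:
        # nonlocal lets every call share the same running totals.
        nonlocal cuda_use, cpu_use
        mem = t.numel() * t.element_size()
        if t.device.type == 'cuda':
            cuda_use += mem
        else:
            cpu_use += mem

    for t in tensors:
        update(t)
    return cuda_use, cpu_use

# 10 float32 (40 B) + 10 float64 (80 B), both on CPU -> (0, 120)
print(tally([torch.zeros(10), torch.zeros(10, dtype=torch.float64)]))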
@@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.grad_scaler.update(found_inf)

         if found_inf:
-            self._logger.info('found inf during ShardedOptimV2 step')
+            self._logger.warning('found inf during ShardedOptimV2 step')
             self.zero_grad()
             return

@@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Now p.data is sharded
         # So optimizer states are sharded naturally

+        self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+                           ranks=[0])
+
         ret = self.optim.step(*args, **kwargs)

+        self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+                           ranks=[0])
         # Copy master param data (fp32) to payload of col_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
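A generic way to bracket an optimizer step with a memory check, analogous to the debug lines added above; this sketch uses torch.cuda.memory_allocated() instead of ShardedOptimizerV2.get_memory_usage(), so it only approximates what the commit logs:

import torch

def step_with_mem_report(optim: torch.optim.Optimizer) -> None:
    before = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    optim.step()
    after = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    print(f"optimizer step changed CUDA allocation by {(after - before) / 1e6:.1f} MB")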
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
+from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage


 class ShardedParamV2(object):
@@ -55,10 +56,9 @@ class ShardedParamV2(object):
             assert isinstance(t, torch.Tensor)
             nonlocal cuda_mem_use
             nonlocal cpu_mem_use
-            if t.device.type == 'cpu':
-                cpu_mem_use += t.numel() * t.element_size()
-            elif t.device.type == 'cuda':
-                cuda_mem_use += t.numel() * t.element_size()
+            t_cuda, t_cpu = colo_tensor_mem_usage(t)
+            cuda_mem_use += t_cuda
+            cpu_mem_use += t_cpu

         address_set = set()
         _update_mem_use(self.sharded_data_tensor.payload)
@@ -1,6 +1,7 @@
 from functools import partial

 import colossalai
+from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.distributed as dist
@@ -57,11 +58,12 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

-    with ZeroInitContext(convert_fp16=True,
-                         target_device=torch.device(f'cpu:0'),
-                         shard_strategy=shard_strategy,
-                         shard_param=True,
-                         rm_torch_payload_on_the_fly=False):
+    with ZeroInitContext(
+            convert_fp16=True,
+            target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
+            shard_strategy=shard_strategy,
+            shard_param=True,
+            rm_torch_payload_on_the_fly=False):
         zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
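The conditional target_device above picks CPU when cpu_offload is enabled and the current GPU otherwise; the same idiom in isolation, using torch.cuda.current_device() as a stand-in for ColossalAI's get_current_device():

import torch

def init_target_device(cpu_offload: bool) -> torch.device:
    if cpu_offload:
        return torch.device('cpu:0')
    # Assumes CUDA is available when offload is disabled.
    return torch.device(f'cuda:{torch.cuda.current_device()}')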