[zero] sharded optim support hybrid cpu adam (#486)

* sharded optim support hybrid cpu adam

* update unit test

* polish docstring
pull/491/head
ver217 2022-03-22 14:56:59 +08:00 committed by GitHub
parent b334822163
commit 62b0a8d644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 48 deletions

View File

@ -1,9 +1,13 @@
import torch
import math
import torch
class CPUAdam(torch.optim.Optimizer):
optimizer_id = 0
# Number of fp32 shards per parameter
# Param weight, grad, momentum and variance
num_fp32_shards_per_param = 4
def __init__(self,
model_params,
@ -106,10 +110,6 @@ class CPUAdam(torch.optim.Optimizer):
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
elif target_device.type == 'cuda':
# FIXME() prepare grad on cuda
if p.grad.device.type == 'cpu':
p.grad = p.grad.to(target_device)
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"

View File

@ -70,7 +70,7 @@ class ShardedModelV2(nn.Module):
sharded.append(param.col_attr.param_is_sharded)
unsharded.append(not param.col_attr.param_is_sharded)
assert all(sharded) or all(
unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded nwo.'
unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
self.shard_param = all(sharded)
self.module = module
@ -96,6 +96,10 @@ class ShardedModelV2(nn.Module):
self.fp32_reduce_scatter = fp32_reduce_scatter
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
for param in module.parameters():
# Init `offload_fp32_grad`
param.col_attr.offload_fp32_grad = self._cpu_offload
# We found that if gradient_predivide_factor != 1.0, there may be precision problems
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set
@ -184,7 +188,7 @@ class ShardedModelV2(nn.Module):
# the shape of `grad` is the same as that of the unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
if self._cpu_offload:
if p.col_attr.offload_fp32_grad:
col_move_to_cpu(grad)
if p.col_attr.fp32_grad is not None:
p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))

View File

@ -28,48 +28,29 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
"""A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO) stage 3.
You must use `ShardedOptimizerV2` with `ShardedModelV2`.
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
:type sharded_model: sharded_model
:param optimizer: A Optimizer instance.
:type optimizer: Optimizer
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param initial_scale: initial scale used by DynamicGradScaler
:type initial_scale: float
:param min_scale: min scale used by DynamicGradScaler
:type min_scale: float
:param growth_factor: growth_factor used by DynamicGradScaler
:type growth_factor: float
:param backoff_factor: backoff_factor used by DynamicGradScaler
:type backoff_factor: float
:param growth_interval: growth_interval used by DynamicGradScaler
:type growth_interval: float
:param hysteresis: hysteresis used by DynamicGradScaler
:type hysteresis: float
:param max_scale: max_scale used by DynamicGradScaler
:type max_scale: float
:param dp_process_group: data paralle process group
:type dp_process_group: Optional[ProcessGroup]
:param mp_process_group: model paralle process group
:type mp_process_group: Optional[ProcessGroup]
"""
Args:
sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
optimizer (Optimizer): An Optimizer instance.
cpu_offload (bool, optional): Whether to offload the optimizer states to CPU. Defaults to False.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer. Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
"""
def __init__(self,
sharded_model: ShardedModelV2,
optimizer: Optimizer,
cpu_offload: bool = False,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
@ -88,6 +69,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
raise RuntimeError(
f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
)
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
# and it must set `num_fp32_shards_per_param` correctly
self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
optimizer, 'num_fp32_shards_per_param', 0) >= 2
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
self.optim_state: OptimState = OptimState.UNSCALED
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@ -122,6 +110,20 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
def step(self, *args, **kwargs):
if self._should_move_fp32_shards_h2d:
self._should_move_fp32_shards_h2d = False
available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
fp32_shards_used_cuda_margin_mem = 0
for group in self.optim.param_groups:
for p in group['params']:
shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
p.grad.data = p.grad.data.to(torch.cuda.current_device())
p.col_attr.offload_fp32_grad = False
fp32_shards_used_cuda_margin_mem += shard_mem
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()

View File

@ -13,6 +13,8 @@ class ShardedParamV2(object):
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
self.fp16_grad: Optional[torch.Tensor] = None
self.fp32_grad: Optional[torch.Tensor] = None
# This attribute must be initialized in ShardedModel
self.offload_fp32_grad: bool = False
# make sure the shared param is the only owner of payload
# The param.data maybe used to init the other part of the model.

View File

@ -5,6 +5,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize
from colossalai.utils import free_port
@ -18,7 +19,6 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_sharded_params_padding
from colossalai.amp import convert_to_apex_amp
def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@ -42,12 +42,15 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
return
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
@ -61,7 +64,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None)
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
@ -71,7 +75,11 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
optimizer_class = CPUAdam
optim = optimizer_class(model.parameters(), lr=1e-3)
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
sharded_optim = ShardedOptimizerV2(zero_model,
sharded_optim,
cpu_offload=cpu_offload,
initial_scale=2**5,
gpu_margin_mem_ratio=gpu_margin_mem_ratio)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)