[zero] sharded optim support hybrid cpu adam (#486)

* sharded optim support hybrid cpu adam

* update unit test

* polish docstring
Author: ver217 (committed via GitHub)
commit 62b0a8d644 (parent b334822163)

@@ -1,9 +1,13 @@
-import torch
 import math
+
+import torch


 class CPUAdam(torch.optim.Optimizer):
     optimizer_id = 0
+    # Number of fp32 shards per parameter:
+    # param weight, grad, momentum and variance
+    num_fp32_shards_per_param = 4

     def __init__(self,
                  model_params,
@@ -106,10 +110,6 @@ class CPUAdam(torch.optim.Optimizer):
                                           group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                           state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
             elif target_device.type == 'cuda':
-                # FIXME() prepare grad on cuda
-                if p.grad.device.type == 'cpu':
-                    p.grad = p.grad.to(target_device)
-
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
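
The new `num_fp32_shards_per_param = 4` class attribute advertises CPUAdam's fp32 working set (weight, grad, momentum, variance) so a wrapping optimizer can budget GPU memory for it. A minimal sketch of that accounting; the helper name is hypothetical, not part of this commit:

    import torch

    FP32_BYTES = 4    # each fp32 shard stores 4 bytes per element

    def fp32_state_bytes(param: torch.Tensor, num_fp32_shards_per_param: int = 4) -> int:
        # One fp32 copy each of weight, grad, momentum and variance per parameter shard
        return param.numel() * FP32_BYTES * num_fp32_shards_per_param

    print(fp32_state_bytes(torch.empty(1_000_000)))    # 16000000, i.e. ~16 MB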

@@ -70,7 +70,7 @@ class ShardedModelV2(nn.Module):
             sharded.append(param.col_attr.param_is_sharded)
             unsharded.append(not param.col_attr.param_is_sharded)
         assert all(sharded) or all(
-            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded nwo.'
+            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)
         self.module = module
@@ -96,6 +96,10 @@ class ShardedModelV2(nn.Module):
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        for param in module.parameters():
+            # Init `offload_fp32_grad`
+            param.col_attr.offload_fp32_grad = self._cpu_offload
+
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
         # However, if you set gradient_predivide_factor to None, we will set
@@ -184,7 +188,7 @@ class ShardedModelV2(nn.Module):
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
-            if self._cpu_offload:
+            if p.col_attr.offload_fp32_grad:
                 col_move_to_cpu(grad)
             if p.col_attr.fp32_grad is not None:
                 p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
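
Taken together, the two ShardedModelV2 changes make fp32-grad placement a per-parameter decision (`p.col_attr.offload_fp32_grad`) instead of reusing the model-wide `_cpu_offload` flag, which is what later lets the optimizer keep selected grads on GPU. A condensed sketch of the resulting placement rule, simplified and not the actual reduce path:

    import torch

    def place_reduced_grad(col_attr, grad: torch.Tensor) -> torch.Tensor:
        # `offload_fp32_grad` starts out equal to the model-wide cpu_offload setting,
        # but ShardedOptimizerV2.step() may flip it to False for individual params.
        return grad.to('cpu') if col_attr.offload_fp32_grad else grad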

@@ -28,48 +28,29 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO) stage 3.

     You must use `ShardedOptimizerV2` with `ShardedModelV2`.

-    :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
-    shard strategy provided by sharded model to shard param fp32 tensors.
-    :type sharded_model: sharded_model
-
-    :param optimizer: A Optimizer instance.
-    :type optimizer: Optimizer
-
-    :param cpu_offload: is offloading the optimizer states to CPU.
-    :type cpu_offload: bool
-
-    :param initial_scale: initial scale used by DynamicGradScaler
-    :type initial_scale: float
-
-    :param min_scale: min scale used by DynamicGradScaler
-    :type min_scale: float
-
-    :param growth_factor: growth_factor used by DynamicGradScaler
-    :type growth_factor: float
-
-    :param backoff_factor: backoff_factor used by DynamicGradScaler
-    :type backoff_factor: float
-
-    :param growth_interval: growth_interval used by DynamicGradScaler
-    :type growth_interval: float
-
-    :param hysteresis: hysteresis used by DynamicGradScaler
-    :type hysteresis: float
-
-    :param max_scale: max_scale used by DynamicGradScaler
-    :type max_scale: float
-
-    :param dp_process_group: data paralle process group
-    :type dp_process_group: Optional[ProcessGroup]
-
-    :param mp_process_group: model paralle process group
-    :type mp_process_group: Optional[ProcessGroup]
+    Args:
+        sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
+            shard strategy provided by the sharded model to shard param fp32 tensors.
+        optimizer (Optimizer): An Optimizer instance.
+        cpu_offload (bool, optional): Whether to offload optimizer states to CPU. Defaults to False.
+        gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
+            which will be used when using hybrid CPU optimizer. Defaults to 0.0.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
+        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
+        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
+        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
+        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
+        mp_process_group (Optional[ProcessGroup], optional): Model parallel process group. Defaults to None.
     """

     def __init__(self,
                  sharded_model: ShardedModelV2,
                  optimizer: Optimizer,
                  cpu_offload: bool = False,
+                 gpu_margin_mem_ratio: float = 0.0,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
                  growth_factor: float = 2,
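
For orientation, a hedged usage sketch of the new knob. `my_model` and `shard_strategy` are assumed to exist already, and the import paths reflect the package layout around the time of this commit; treat both as assumptions:

    from colossalai.nn.optimizer import CPUAdam
    from colossalai.zero.sharded_model import ShardedModelV2
    from colossalai.zero.sharded_optim import ShardedOptimizerV2

    zero_model = ShardedModelV2(my_model,
                                shard_strategy,
                                offload_config=dict(device='cpu'),
                                use_memory_tracer=True)    # tracer measures the CUDA margin
    optim = CPUAdam(zero_model.parameters(), lr=1e-3)      # hybrid CPU/CUDA Adam
    # Spend 70% of the measured CUDA margin on fp32 master shards after the first step
    sharded_optim = ShardedOptimizerV2(zero_model, optim,
                                       cpu_offload=True,
                                       gpu_margin_mem_ratio=0.7)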
@@ -88,6 +69,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             raise RuntimeError(
                 f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
             )
+        self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
+        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, 'gpu_margin_mem_ratio must be >= 0.0 and <= 1.0'
+        # Only move fp32 shards from CPU to GPU when the user allows it and the inner optimizer is valid.
+        # The inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
+        # and it must set `num_fp32_shards_per_param` correctly.
+        self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
+            optimizer, 'num_fp32_shards_per_param', 0) >= 2
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
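
The `getattr(optimizer, 'num_fp32_shards_per_param', 0) >= 2` clause is a duck-typed capability check: only optimizers that declare their fp32 working set (CPUAdam above sets it to 4) opt in, while any other optimizer falls back to the 0 default and the hybrid move stays disabled. For example:

    import torch

    plain = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))])
    print(getattr(plain, 'num_fp32_shards_per_param', 0) >= 2)    # False: feature off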
@@ -122,6 +110,20 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)

     def step(self, *args, **kwargs):
+        if self._should_move_fp32_shards_h2d:
+            self._should_move_fp32_shards_h2d = False
+            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
+            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
+            fp32_shards_used_cuda_margin_mem = 0
+            for group in self.optim.param_groups:
+                for p in group['params']:
+                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
+                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
+                        p.col_attr.offload_fp32_grad = False
+                        fp32_shards_used_cuda_margin_mem += shard_mem
+
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
             self._unscale_grads()
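
The budgeting in `step()` worked through once with illustrative numbers; the margin value is made up here, whereas in the real code `cuda_margin_space` is measured by the memory tracer after the first forward-backward pass:

    cuda_margin_space = 2 * 1024**3     # suppose a 2 GiB margin was measured
    gpu_margin_mem_ratio = 0.7
    num_fp32_shards_per_param = 4       # weight, grad, momentum, variance

    available_cuda_margin_mem = cuda_margin_space * gpu_margin_mem_ratio    # 1.4 GiB
    budget = available_cuda_margin_mem / num_fp32_shards_per_param          # ~0.35 GiB of master weights
    # Shards are promoted greedily in param-group order until the next shard
    # would overflow the budget; promoted params also keep their grads on GPU.
    print(f'{budget / 1024**2:.0f} MiB of fp32 master shards may move to GPU')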

@@ -13,6 +13,8 @@ class ShardedParamV2(object):
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
+        # This attribute must be initialized in ShardedModelV2
+        self.offload_fp32_grad: bool = False

         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.

@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,7 +19,6 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

 from common import CONFIG, check_sharded_params_padding
-from colossalai.amp import convert_to_apex_amp


 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -42,12 +42,15 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("cpu_offload", [True, False])
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
+@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy_class()

     if use_cpuadam and cpu_offload is False:
         return
+    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
+        return

     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -61,7 +64,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
         zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
-                                    offload_config=dict(device='cpu') if cpu_offload else None)
+                                    offload_config=dict(device='cpu') if cpu_offload else None,
+                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)

         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -71,7 +75,11 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
             optimizer_class = CPUAdam
         optim = optimizer_class(model.parameters(), lr=1e-3)
         sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           sharded_optim,
+                                           cpu_offload=cpu_offload,
+                                           initial_scale=2**5,
+                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)

         amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
         apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
