diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index f3ef91d0e..59ee5f9bd 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -1,9 +1,13 @@
-import torch
 import math
+
+import torch
 
 
 class CPUAdam(torch.optim.Optimizer):
     optimizer_id = 0
+    # Number of fp32 shards per parameter:
+    # param weight, grad, momentum and variance
+    num_fp32_shards_per_param = 4
 
     def __init__(self,
                  model_params,
@@ -106,10 +110,6 @@ class CPUAdam(torch.optim.Optimizer):
                                              group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                              state['exp_avg'], state['exp_avg_sq'], self.loss_scale)
             elif target_device.type == 'cuda':
-                # FIXME() prepare grad on cuda
-                if p.grad.device.type == 'cpu':
-                    p.grad = p.grad.to(target_device)
-
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 7ca598e82..4f0297d34 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -70,7 +70,7 @@ class ShardedModelV2(nn.Module):
             sharded.append(param.col_attr.param_is_sharded)
             unsharded.append(not param.col_attr.param_is_sharded)
         assert all(sharded) or all(
-            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded nwo.'
+            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)
         self.module = module
@@ -96,6 +96,10 @@ class ShardedModelV2(nn.Module):
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        for param in module.parameters():
+            # Init `offload_fp32_grad`
+            param.col_attr.offload_fp32_grad = self._cpu_offload
+
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
         # However, if you set gradient_predivide_factor to None, we will set
@@ -184,7 +188,7 @@ class ShardedModelV2(nn.Module):
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
-            if self._cpu_offload:
+            if p.col_attr.offload_fp32_grad:
                 col_move_to_cpu(grad)
             if p.col_attr.fp32_grad is not None:
                 p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 5ec57c083..8ccec731d 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -28,48 +28,29 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO) stage 3.
 
     You must use `ShardedOptimizerV2` with `ShardedModelV2`.
 
-    :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
-        shard strategy provided by sharded model to shard param fp32 tensors.
-    :type sharded_model: sharded_model
-
-    :param optimizer: A Optimizer instance.
-    :type optimizer: Optimizer
-
-    :param cpu_offload: is offloading the optimizer states to CPU.
-    :type cpu_offload: bool
-
-    :param initial_scale: initial scale used by DynamicGradScaler
-    :type initial_scale: float
-
-    :param min_scale: min scale used by DynamicGradScaler
-    :type min_scale: float
-
-    :param growth_factor: growth_factor used by DynamicGradScaler
-    :type growth_factor: float
-
-    :param backoff_factor: backoff_factor used by DynamicGradScaler
-    :type backoff_factor: float
-
-    :param growth_interval: growth_interval used by DynamicGradScaler
-    :type growth_interval: float
-
-    :param hysteresis: hysteresis used by DynamicGradScaler
-    :type hysteresis: float
-
-    :param max_scale: max_scale used by DynamicGradScaler
-    :type max_scale: float
-
-    :param dp_process_group: data paralle process group
-    :type dp_process_group: Optional[ProcessGroup]
-
-    :param mp_process_group: model paralle process group
-    :type mp_process_group: Optional[ProcessGroup]
-    """
+    Args:
+        sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
+            shard strategy provided by the sharded model to shard param fp32 tensors.
+        optimizer (Optimizer): An Optimizer instance.
+        cpu_offload (bool, optional): Whether to offload optimizer states to CPU. Defaults to False.
+        gpu_margin_mem_ratio (float, optional): The ratio of GPU memory remaining after the first forward-backward pass
+            which is used by the hybrid CPU optimizer. Defaults to 0.0.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
+        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
+        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
+        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
+        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
+        dp_process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
+        mp_process_group (Optional[ProcessGroup], optional): Model parallel process group. Defaults to None.
+ """ def __init__(self, sharded_model: ShardedModelV2, optimizer: Optimizer, cpu_offload: bool = False, + gpu_margin_mem_ratio: float = 0.0, initial_scale: float = 2**32, min_scale: float = 1, growth_factor: float = 2, @@ -88,6 +69,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer): raise RuntimeError( f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload" ) + self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid + # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, + # and it must set `num_fp32_shards_per_param` correctly + self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr( + optimizer, 'num_fp32_shards_per_param', 0) >= 2 self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu') self.optim_state: OptimState = OptimState.UNSCALED self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) @@ -122,6 +110,20 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group) def step(self, *args, **kwargs): + if self._should_move_fp32_shards_h2d: + self._should_move_fp32_shards_h2d = False + available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio + fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param + fp32_shards_used_cuda_margin_mem = 0 + for group in self.optim.param_groups: + for p in group['params']: + shard_mem = self.master_params[p].numel() * self.master_params[p].element_size() + if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem: + self.master_params[p] = self.master_params[p].to(torch.cuda.current_device()) + p.grad.data = p.grad.data.to(torch.cuda.current_device()) + p.col_attr.offload_fp32_grad = False + fp32_shards_used_cuda_margin_mem += shard_mem + # unscale grads if scaled if self.optim_state == OptimState.SCALED: self._unscale_grads() diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index de5f8f849..2a777d14e 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -13,6 +13,8 @@ class ShardedParamV2(object): self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group) self.fp16_grad: Optional[torch.Tensor] = None self.fp32_grad: Optional[torch.Tensor] = None + # This attribute must be initialized in ShardedModel + self.offload_fp32_grad: bool = False # make sure the shared param is the only owner of payload # The param.data maybe used to init the other part of the model. 
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 3f3149400..6de799c80 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -5,6 +5,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
@@ -18,7 +19,6 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_sharded_params_padding
-from colossalai.amp import convert_to_apex_amp
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -42,12 +42,15 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("cpu_offload", [True, False])
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
+@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy_class()
 
     if use_cpuadam and cpu_offload is False:
         return
+    if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam):
+        return
 
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -61,7 +64,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
         zero_model = model_builder(checkpoint=True)
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
-                                    offload_config=dict(device='cpu') if cpu_offload else None)
+                                    offload_config=dict(device='cpu') if cpu_offload else None,
+                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)
 
         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -71,7 +75,11 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
             optimizer_class = CPUAdam
         optim = optimizer_class(model.parameters(), lr=1e-3)
         sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           sharded_optim,
+                                           cpu_offload=cpu_offload,
+                                           initial_scale=2**5,
+                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
         amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
         apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
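For reference, a condensed usage sketch distilled from the test above (not part of the patch). The small `nn.Sequential` model is a placeholder, the `TensorShardStrategy` import path is assumed from the repository layout, and the snippet is expected to run inside the usual colossalai distributed launch so that the data-parallel process group exists.

```python
import torch
import torch.nn as nn

from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed import path
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


def build_zero3_with_hybrid_adam():
    # Placeholder model; any nn.Module built the usual way works here.
    model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))

    # Wrap with ZeRO-3: offload fp32 grads/states to CPU and enable the memory
    # tracer so `cuda_margin_space` is measured after the first iteration.
    zero_model = ShardedModelV2(model,
                                TensorShardStrategy(),
                                offload_config=dict(device='cpu'),
                                use_memory_tracer=True)

    # CPUAdam can update both CPU and CUDA tensors and exposes
    # `num_fp32_shards_per_param`, which the hybrid placement relies on.
    optim = CPUAdam(zero_model.parameters(), lr=1e-3)
    optim = ShardedOptimizerV2(zero_model,
                               optim,
                               cpu_offload=True,
                               gpu_margin_mem_ratio=0.7,   # use up to 70% of the measured margin
                               initial_scale=2**5)
    return zero_model, optim
```

With `gpu_margin_mem_ratio=0.7`, the first call to `step()` promotes fp32 master shards (and their gradients) back to GPU until roughly 70% of the measured CUDA margin is accounted for, while the remaining shards keep being updated on CPU by CPUAdam.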