diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 59ee5f9bd..ba3c188ac 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
                           bias_correction2,
                           loss_scale,
                           use_adamw=False):
+        # FIXME(ver217): remove the line below when replacing torch adam with fused adam
+        grad = grad.float()
         if loss_scale is not None:
             grad.div_(loss_scale)
 
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 70e14548b..afce76933 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -29,24 +29,22 @@ class ShardedModelV2(nn.Module):
     compared to classic data parallelism while the computational granularity and communication
     efficiency are retained.
     Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
-    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
-    :type module: nn.Module
-    :param shard_strategy: A shard strategy to manage shard behavior.
-    :type shard_strategy: BaseShardStrategy
-    :param process_group: Data parallel process group, defaults to None
-    :type process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
-    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
-    :type reduce_scatter_bucket_size_mb: int, optional
-    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
-    :type fp32_reduce_scatter: bool, optional
-    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
-    :type offload_config: Optional[dict], optional
-    :param gradient_predivide_factor: Gradient is divived by this value before reduce-scatter, defaults to 1.0
-    :type gradient_predivide_factor: Optional[float], optional
-    :param use_memory_tracer: Whether to use memoty tracer, defaults to False
-    :type use_memory_tracer: bool, optional
+    Args:
+        module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
+        shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
+        process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
+        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
+            Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
+        reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
+        fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
+        offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
+        gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
+        use_memory_tracer (bool, optional): Whether to use the memory tracer. Defaults to False.
+        reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for param and grad.
+            Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
+            In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
+            We find that PyTorch's optimizers don't support mixed precision,
+            so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
     """

     def __init__(self,
@@ -58,7 +56,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 use_memory_tracer: bool = False):
+                 use_memory_tracer: bool = False,
+                 reuse_fp16_shard: bool = False):
         super().__init__()
         self.logger = get_dist_logger()

@@ -97,8 +96,8 @@ class ShardedModelV2(nn.Module):
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
-            # Init `offload_fp32_grad`
-            param.col_attr.offload_fp32_grad = self._cpu_offload
+            # Init `offload_grad`
+            param.col_attr.offload_grad = self._cpu_offload

         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -114,6 +113,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True

         self._cuda_margin_space = 0
+        self.reuse_fp16_shard = reuse_fp16_shard

     @property
     def cuda_margin_space(self):
@@ -143,11 +143,7 @@ class ShardedModelV2(nn.Module):
         for ophook in self._ophook_list:
             ophook.post_iter()

-    @torch.no_grad()
-    def _post_backward_operations(self) -> None:
-        """
-        The method includes operations required to be processed after backward
-        """
+    def _update_memstats(self):
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
         if self._memstats_collector:
@@ -160,6 +156,13 @@ class ShardedModelV2(nn.Module):

         self._iter_cnter += 1

+    @torch.no_grad()
+    def _post_backward_operations(self) -> None:
+        """
+        The method includes operations required to be processed after backward
+        """
+        self._update_memstats()
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -171,9 +174,11 @@ class ShardedModelV2(nn.Module):
         self.reducer.free()
         # In case some post bwd hook is not fired
         if self.shard_param:
+            tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
+                    tensor_list.append(p.col_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -191,13 +196,17 @@ class ShardedModelV2(nn.Module):
             # If world size == 1 and sharded param,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-            grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
-            if p.col_attr.offload_fp32_grad:
+            if self.reuse_fp16_shard:
+                grad = p.col_attr.sharded_data_tensor.payload
+            else:
+                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+            if p.col_attr.offload_grad:
                 col_move_to_cpu(grad)
             if p.col_attr.fp32_grad is not None:
+                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                 p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
                 grad = p.col_attr.fp32_grad
-            p.grad.data = grad.view(-1)
+            p.grad.data = grad
             p.col_attr.fp16_grad = None
             p.col_attr.fp32_grad = None

@@ -250,11 +259,15 @@ class ShardedModelV2(nn.Module):
         return empty_grad

     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        reduced_grad = reduced_grad.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-
-        param.col_attr.fp16_grad = reduced_grad.data
+        if self.reuse_fp16_shard:
+            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.col_attr.sharded_data_tensor.is_sharded = True
+        else:
+            param.col_attr.fp16_grad = reduced_grad.data

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 2109d4499..b3507c132 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -224,5 +224,5 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_fp32_grad = False
+                    p.col_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 2a777d14e..5826dde96 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -14,7 +14,7 @@ class ShardedParamV2(object):
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
         # This attribute must be initialized in ShardedModel
-        self.offload_fp32_grad: bool = False
+        self.offload_grad: bool = False

         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 7e6f881dc..70166e121 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
                           use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy())
+                          shard_strategy=TensorShardStrategy(),
+                          reuse_fp16_shard=False)

 _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                               initial_scale=2**5,
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)


-def check_sharded_params_padding(model, zero_model, loose=False):
+def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+        if reuse_fp16_shard:
+            zero_p = zero_p.data.to(p.device).float()
+        else:
+            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 6de799c80..a8d9c0874 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import CONFIG, check_sharded_params_padding
+from common import CONFIG, check_sharded_model_params


 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None,
-                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)
+                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                    reuse_fp16_shard=use_cpuadam)

     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         data, label = data.cuda(), label.cuda()
         _run_step(apex_model, apex_optimizer, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
-        check_sharded_params_padding(model, zero_model, loose=True)
+        check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)

     for param in model.parameters():
         assert not has_inf_or_nan(param)
diff --git a/tests/test_zero_data_parallel/test_zero_engine.py b/tests/test_zero_data_parallel/test_zero_engine.py
index 56ad85203..c1fb6b2bb 100644
--- a/tests/test_zero_data_parallel/test_zero_engine.py
+++ b/tests/test_zero_data_parallel/test_zero_engine.py
@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
+from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)


 def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
     elif parallel_config == ZERO_PARALLEL_CONFIG:
-        check_sharded_params_padding(torch_model, colo_model, loose=True)
+        check_sharded_model_params(torch_model, colo_model, loose=True)


 # FIXME: enable this test in next PR
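
For reference, a minimal usage sketch of the new `reuse_fp16_shard` flag (not part of the diff above). It mirrors the updated `test_sharded_optim_v2.py`, with a toy `nn.Linear` standing in for a real model; the `ZeroInitContext` and `ShardedOptimizerV2` keyword arguments are assumed to match this version of the codebase.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

shard_strategy = TensorShardStrategy()

# The wrapped module must be created inside ZeroInitContext so that its
# parameters are sharded and carry the `col_attr` attribute.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cpu'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(128, 256)

# With reuse_fp16_shard=True the reduce-scattered fp16 grad is written back into
# the param's sharded data tensor, so the optimizer must accept fp32 params with
# fp16 grads. Per the docstring above, CPUAdam with CPU offload is the recommended pairing.
zero_model = ShardedModelV2(model,
                            shard_strategy,
                            offload_config=dict(device='cpu'),
                            reuse_fp16_shard=True)
optimizer = CPUAdam(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model, optimizer, cpu_offload=True, initial_scale=2**5)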