From 7c6c427db11b6a4ae852c488d1124e91794ac927 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 31 Mar 2022 16:26:54 +0800 Subject: [PATCH] [zero] trace states of fp16/32 grad and fp32 param (#571) --- colossalai/utils/memory_utils/utils.py | 6 +- .../zero/sharded_model/sharded_model_v2.py | 37 +++----- .../zero/sharded_optim/sharded_optim_v2.py | 88 +++++++++++-------- .../zero/sharded_param/sharded_param.py | 5 -- .../test_shard_param.py | 6 -- 5 files changed, 69 insertions(+), 73 deletions(-) diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py index 90b7438d3..1f6fe332f 100644 --- a/colossalai/utils/memory_utils/utils.py +++ b/colossalai/utils/memory_utils/utils.py @@ -51,9 +51,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_ """ A colossal API for model data tensor move. The src and target tensors could be resident on both CPU and GPU. - + NOTE() The source tensor payload will be removed after this function. - + The function will record the communication volume between CPU and GPU. Args: t_src (Union[StatefulTensor, torch.Tensor]): source tensor @@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}') if isinstance(target_device, int): - target_device = torch.cuda(f'device"{target_device}') + target_device = torch.device(f'cuda:{target_device}') # deal with torch.device('cpu') and torch.device('cpu:0) if t_payload.device.type == target_device.type: diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index a27da5e3b..1e60be6da 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \ from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu) from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer -from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState) +from colossalai.zero.sharded_param.tensorful_state import TensorState from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter @@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module): # We also allows to interleave no-sync pass with sync passes, if desired. if not self._require_backward_grad_sync: continue - # Reduced grad is saved in `p.colo_attr.saved_grad` - # It can be on CPU or CUDA - # It can be fp16 or fp32 - # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`. 
-            if self.reuse_fp16_shard:
-                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
-            else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
-            assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.colo_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.colo_attr.saved_grad.is_null():
-                assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                # Accumulate grad, saved grad must be fp32
-                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
-                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
-            else:
-                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
-            p.grad = None
-            p.colo_attr.fp16_grad.set_null()
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -322,11 +302,22 @@ class ShardedModelV2(nn.Module):
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # FIXME(ver217): remove the line below once the eviction policy is implemented
+        if param.colo_attr.offload_grad:
+            colo_model_data_move_to_cpu(reduced_grad)
         if self.reuse_fp16_shard:
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            assert param.colo_attr.saved_grad.is_null(
+            ), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
             param.colo_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
         else:
-            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            if param.colo_attr.saved_grad.is_null():
+                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+            else:
+                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 539350101..338870bde 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                  colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger = get_dist_logger("ShardedOptimizerV2")
 
         # Store fp32 param shards
-        self.master_params: Dict[Parameter, Tensor] = {}
+        self.master_params: Dict[Parameter, StatefulTensor] = {}
 
         for group in self.optim.param_groups:
             for p in group['params']:
@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return
 
-        # assign master param pointers to p.data.
-        # We will not trigger data copy here.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                p.data = self.master_params[p]
-                # Now p.data is sharded
-                # So optimizer states are sharded naturally
+        self._prepare_data()
 
         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of colo_attr (fp16)
-        # TODO() improve efficiency by gathering tensors into a chunk and transfering
-        # a chunk.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-
-                if not is_param_sharded:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.colo_attr.sharded_data_tensor.payload
+        self._write_back_data()
         return ret
 
     def backward(self, loss: Tensor) -> None:
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
        self.optim.zero_grad(set_to_none=True)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.colo_attr.saved_grad.set_null()
 
     def sync_grad(self):
         pass
@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
-                shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                    self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _prepare_grads(self):
         for group in self.optim.param_groups:
             for p in group['params']:
+                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
                 p.data = p.colo_attr.saved_grad.payload
                 p.grad = p.colo_attr.saved_grad.payload
+                # Set p.data to an empty tensor to avoid a memory leak
+                p.colo_attr.remove_torch_payload()
 
+    def _prepare_data(self):
+        # Assign master param pointers to p.data.
+        # We will not trigger data copy here.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                self.master_params[p].trans_state(TensorState.COMPUTE)
+                p.data = self.master_params[p].payload
+                # Now p.data is sharded
+                # So optimizer states are sharded naturally
+
+    def _write_back_data(self):
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transferring
+        # a chunk.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
+                    # But we only have updated fp32 param shard here
+                    # So we first shard full fp16 param and copy fp32 param shard to it
+                    # Then we will gather them
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                # We have to use `copy_payload` instead of `reset_payload`
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+                p.colo_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+
+                if not is_param_sharded:
+                    # We gather full fp16 param here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
+                self.master_params[p].trans_state(TensorState.HOLD)
+                p.colo_attr.saved_grad.set_null()
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 6a3faa636..f3124127b 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -10,7 +10,6 @@ class ShardedParamV2(object):
 
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
-        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
@@ -57,10 +56,6 @@ class ShardedParamV2(object):
             _update_mem_use(self.sharded_data_tensor.payload)
             address_set.add(self.sharded_data_tensor.payload.data_ptr())
 
-        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad.payload)
-            address_set.add(self.fp16_grad.data_ptr())
-
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set: _update_mem_use(self.saved_grad.payload) address_set.add(self.saved_grad.data_ptr()) diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py index af833857f..2d2ae1075 100644 --- a/tests/test_zero_data_parallel/test_shard_param.py +++ b/tests/test_zero_data_parallel/test_shard_param.py @@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port): # 4 is size of dummy tensor of param.data assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 - sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half()) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 - assert cuda_mem_use == 2 * 3 * 2 - - sparam.fp16_grad = StatefulTensor(None) sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) sparam.remove_torch_payload() cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
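
Overview of the state flow this patch introduces: each parameter now carries a single gradient buffer, `p.colo_attr.saved_grad`, and the optimizer keeps one fp32 master shard per parameter in `self.master_params[p]`; both are tracked as stateful tensors and the separate `fp16_grad` buffer is removed. After reduce-scatter the gradient is accumulated into `saved_grad` (cast to fp32 unless `reuse_fp16_shard` is set) and left in the HOLD state; `_prepare_grads` and `_prepare_data` move the gradient and the master shard to COMPUTE and alias them into `p.grad`/`p.data` for `optim.step()`; `_write_back_data` then casts the updated master shard back into the fp16 payload and releases the gradient. The sketch below is a minimal, self-contained illustration of that lifecycle; `MiniStatefulTensor`, `State`, and the two driver functions are simplified stand-ins written for this example, not the `StatefulTensor`/`TensorState` API from `colossalai.zero.sharded_param.tensorful_state`, and it runs single-process on CPU with no sharding.

# Illustrative sketch only (not ColossalAI code): a toy version of the
# saved_grad / master_param state tracking described above.
from enum import Enum

import torch


class State(Enum):
    FREE = 0       # no payload attached
    HOLD = 1       # payload valid, idle between backward and step
    COMPUTE = 2    # payload aliased into p.grad / p.data for optim.step()


class MiniStatefulTensor:
    """Simplified stand-in for a stateful tensor: a payload plus a coarse state."""

    def __init__(self, tensor: torch.Tensor = None):
        self.payload = tensor
        self.state = State.FREE if tensor is None else State.HOLD

    def is_null(self) -> bool:
        return self.payload is None

    def reset_payload(self, tensor: torch.Tensor) -> None:
        self.payload = tensor
        self.state = State.HOLD

    def trans_state(self, state: State) -> None:
        self.state = state

    def set_null(self) -> None:
        self.payload = None
        self.state = State.FREE


def after_reduce_scatter(saved_grad: MiniStatefulTensor, reduced_grad: torch.Tensor) -> None:
    # Mirrors the fp32 branch of the new reduce-scatter callback:
    # cast the reduced fp16 grad to fp32 and accumulate it into saved_grad.
    reduced_grad = reduced_grad.float()
    if saved_grad.is_null():
        saved_grad.reset_payload(reduced_grad)
    else:
        saved_grad.payload.add_(reduced_grad.view_as(saved_grad.payload))
    saved_grad.trans_state(State.HOLD)


def optimizer_step(p: torch.nn.Parameter, saved_grad: MiniStatefulTensor,
                   master_param: MiniStatefulTensor, optim: torch.optim.Optimizer) -> None:
    # Mirrors _prepare_grads/_prepare_data -> optim.step() -> _write_back_data/zero_grad.
    saved_grad.trans_state(State.COMPUTE)
    master_param.trans_state(State.COMPUTE)
    p.data = master_param.payload        # step() updates the fp32 master shard in place
    p.grad = saved_grad.payload
    optim.step()
    p.data = p.data.half()               # write the updated fp32 shard back as an fp16 payload
    p.grad = None
    master_param.trans_state(State.HOLD)
    saved_grad.set_null()                # gradients are consumed once the step is done


if __name__ == '__main__':
    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
    master = MiniStatefulTensor(torch.zeros(4, dtype=torch.float32))
    grad = MiniStatefulTensor()
    optim = torch.optim.SGD([p], lr=0.1)

    after_reduce_scatter(grad, torch.ones(4, dtype=torch.float16))  # first micro-batch
    after_reduce_scatter(grad, torch.ones(4, dtype=torch.float16))  # second one accumulates
    optimizer_step(p, grad, master, optim)

    print(p.data, grad.state, master.state)   # fp16 params, grad FREE, master HOLD

The point mirrored here is that `p.data` is only ever a temporary alias (to the fp32 master shard during the step, back to an fp16 payload afterwards), so optimizer states stay sharded in fp32 while the fp16 gradient buffer never has to outlive the step.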