From 014bac0c493a5882bed0034a6d3a80b3d408b06c Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 30 Mar 2022 18:14:50 +0800
Subject: [PATCH] [zero] hijack p.grad in sharded model (#554)

* hijack p.grad in sharded model

* polish comments

* polish comments
---
 colossalai/engine/ophooks/zero_hook.py      | 19 ++----------
 .../zero/sharded_model/sharded_model_v2.py  | 30 ++++++++-----------
 .../zero/sharded_optim/sharded_optim_v2.py  | 29 ++++++++++++------
 .../zero/sharded_param/sharded_param.py     | 13 +++-----
 tests/test_zero_data_parallel/common.py     |  3 +-
 .../test_shard_param.py                     |  6 ++--
 6 files changed, 45 insertions(+), 55 deletions(-)

diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index d30b786bc..a3e40ba8b 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 
 from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
+
+from colossalai.utils.memory_utils.utils import \
+    colo_model_data_tensor_move_inline
 
 
 @OPHOOKS.register_module
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
-            # Store local accumulated grad shard
-            if param.grad is not None:
-                if param.col_attr.bwd_count == 0:
-                    # We haven't stored local accumulated grad yet
-                    assert param.col_attr.fp32_grad.is_null()
-
-                    # Allocate grad fp32 memory space here
-                    param.col_attr.fp32_grad.reset_payload(param.grad.data)
-                    # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
-                    param.grad = None
-                else:
-                    # We have stored local accumulated grad
-                    # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
-            param.col_attr.bwd_count += 1
 
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 846bfc016..c4aac0001 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -12,20 +12,18 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 
 
 class ShardedModelV2(nn.Module):
@@ -233,11 +231,11 @@ class ShardedModelV2(nn.Module):
             if not p.col_attr.param_is_sharded:
                 tensor_list.append(p.col_attr.sharded_data_tensor)
                 p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                p.col_attr.remove_torch_payload()
         self.shard_strategy.shard(tensor_list, self.process_group)
 
         # 4. move sharded param grad payload to param.grad
         for p in self.module.parameters():
-            p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
             # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
@@ -247,14 +245,10 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad payload kept by sharded param back to p.grad,
-            # and set p.col_attr.grad to None
-            # As sharded optimizer only update a shard of param,
-            # no matter whether we shard param in sharded model
-            # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and param is sharded,
-            # the shape `grad` is the same as unsharded param
-            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
+            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # It can be on CPU or CUDA
+            # It can be fp16 or fp32
+            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
                 grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
@@ -262,13 +256,15 @@ class ShardedModelV2(nn.Module):
             assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.fp32_grad.is_null():
+            if not p.col_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_fp16_payload = p.col_attr.fp32_grad.payload
-                p.col_attr.fp32_grad.set_null()
+                # Accumulate grad, saved grad must be fp32
+                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
+                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+            else:
+                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
 
-            p.grad.data = grad_fp16_payload
+            p.grad = None
             p.col_attr.fp16_grad.set_null()
 
     @torch.no_grad()
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index fa3b2daa4..86fa7aadd 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -5,23 +5,22 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
-
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+                                                 colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from torch.optim import Optimizer
 
 
 class OptimState(Enum):
@@ -170,6 +169,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         return cuda_use, cpu_use
 
     def step(self, *args, **kwargs):
+        self._prepare_grads()
         self._maybe_move_fp32_shards()
 
         # unscale grads if scaled
@@ -294,3 +294,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                         p.grad.data = p.grad.data.to(torch.cuda.current_device())
                         p.col_attr.offload_grad = False
                         fp32_shards_used_cuda_margin_mem += shard_mem
+
+    def _prepare_grads(self):
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information
+                # If we change p.grad directly,
+                # it may raise an error because of different shape/dtype/device of p.data and p.grad
+                # We just set p.data = p.col_attr.saved_grad.payload here
+                p.data = p.col_attr.saved_grad.payload
+                p.grad = p.col_attr.saved_grad.payload
+                p.col_attr.saved_grad.set_null()
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 71e8030ac..6a3faa636 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -11,7 +11,7 @@ class ShardedParamV2(object):
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
-        self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
+        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
 
@@ -24,11 +24,6 @@ class ShardedParamV2(object):
         if rm_torch_payload:
             self.remove_torch_payload()
 
-        # Backward count for handle local grad accumulation
-        # This value will increment by 1 in every pre-bwd hook
-        # And will be reset to 0 in every final-bwd hook
-        self.bwd_count = 0
-
     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
 
@@ -66,9 +61,9 @@ class ShardedParamV2(object):
             _update_mem_use(self.fp16_grad.payload)
             address_set.add(self.fp16_grad.data_ptr())
 
-        if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp32_grad.payload)
-            address_set.add(self.fp32_grad.data_ptr())
+        if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
+            _update_mem_use(self.saved_grad.payload)
+            address_set.add(self.saved_grad.data_ptr())
 
         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
             _update_mem_use(self.param.data)
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 70166e121..cdac165ee 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
+        # zero_grad = zero_p.grad.clone().to(p.device)
+        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 56c77fdee..af833857f 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
     # Test get memory usage
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
 
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 2 * 3 * 2
 
     sparam.fp16_grad = StatefulTensor(None)
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0
 
     # reuse torch grad for sparam
-    sparam.fp32_grad = StatefulTensor(param.grad)
+    sparam.saved_grad = StatefulTensor(param.grad)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
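
Illustrative sketch (not part of the patch): a minimal, torch-only model of the p.grad
hand-off this commit introduces. The sharded model parks the reduced gradient in
`col_attr.saved_grad` and clears `p.grad`; the optimizer re-attaches it right before
`step()`. `StatefulTensor`, `ColAttr` and the function names below are simplified
stand-ins for the real colossalai objects; only the control flow mirrors the patch.

import torch


class StatefulTensor:
    """Holds a tensor payload that can be set to null, loosely modeled on colossalai's StatefulTensor."""

    def __init__(self, payload=None):
        self._payload = payload

    def is_null(self):
        return self._payload is None

    @property
    def payload(self):
        return self._payload

    def reset_payload(self, tensor):
        self._payload = tensor

    def set_null(self):
        self._payload = None


class ColAttr:
    """Stand-in for the per-parameter col_attr used by ShardedParamV2."""

    def __init__(self):
        self.saved_grad = StatefulTensor()


def model_post_backward(p: torch.nn.Parameter, reduced_grad: torch.Tensor):
    # Mirrors the ShardedModelV2 side: keep the reduced (possibly fp16) grad in
    # col_attr.saved_grad, accumulate in fp32 if a grad is already stored, and
    # leave p.grad empty for the optimizer to fill in later.
    if not p.col_attr.saved_grad.is_null():
        p.col_attr.saved_grad.reset_payload(p.col_attr.saved_grad.payload.float())
        p.col_attr.saved_grad.payload.add_(reduced_grad.view_as(p.col_attr.saved_grad.payload))
    else:
        p.col_attr.saved_grad.reset_payload(reduced_grad)
    p.grad = None


def optimizer_prepare_grads(params):
    # Mirrors ShardedOptimizerV2._prepare_grads: point p.data and p.grad at the
    # saved payload so shape/dtype/device always match, then drop the saved copy.
    for p in params:
        p.data = p.col_attr.saved_grad.payload
        p.grad = p.col_attr.saved_grad.payload
        p.col_attr.saved_grad.set_null()


if __name__ == '__main__':
    p = torch.nn.Parameter(torch.zeros(4))
    p.col_attr = ColAttr()
    model_post_backward(p, torch.ones(4, dtype=torch.half))    # first pass stores the fp16 grad
    model_post_backward(p, torch.ones(4, dtype=torch.half))    # second pass accumulates in fp32
    optimizer_prepare_grads([p])
    print(p.grad)
    assert p.grad is not None and torch.equal(p.grad, torch.full((4,), 2.0))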