[zero] hijack p.grad in sharded model (#554)

* hijack p.grad in sharded model

* polish comments

* polish comments
ver217 2022-03-30 18:14:50 +08:00 committed by GitHub
parent f552b11294
commit 014bac0c49
6 changed files with 45 additions and 55 deletions


@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
from colossalai.utils.memory_utils.utils import \
colo_model_data_tensor_move_inline
@OPHOOKS.register_module
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
param.data = param.col_attr.sharded_data_tensor.payload
# Store local accumulated grad shard
if param.grad is not None:
if param.col_attr.bwd_count == 0:
# We haven't stored local accumulated grad yet
assert param.col_attr.fp32_grad.is_null()
# Allocate grad fp32 memory space here
param.col_attr.fp32_grad.reset_payload(param.grad.data)
# TODO(jiaruifang) we should set grad fp16 state to HOLD here.
param.grad = None
else:
# We have stored the local accumulated grad
# The grad here must be the locally computed full grad from this backward pass
assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
param.col_attr.bwd_count += 1
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
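
Not part of the diff: for contrast, a minimal sketch of the bookkeeping this hunk deletes from the pre-bwd hook. After this commit the hook no longer touches p.grad at all; gradient accumulation is handled in ShardedModelV2's reduction path through p.col_attr.saved_grad (see the hunks below). The free-standing function name is illustrative only.

import torch

def _removed_prebwd_grad_stash(param: torch.nn.Parameter) -> None:
    # Pre-commit behaviour, reproduced from the deleted lines above.
    if param.grad is not None:
        if param.col_attr.bwd_count == 0:
            # First backward of the accumulation window: stash the full local grad.
            assert param.col_attr.fp32_grad.is_null()
            param.col_attr.fp32_grad.reset_payload(param.grad.data)
            param.grad = None
        else:
            # Later backwards: the freshly computed full grad is left for the reduce path.
            assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
    param.col_attr.bwd_count += 1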


@@ -12,20 +12,18 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
class ShardedModelV2(nn.Module):
@@ -233,11 +231,11 @@ class ShardedModelV2(nn.Module):
if not p.col_attr.param_is_sharded:
tensor_list.append(p.col_attr.sharded_data_tensor)
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
p.col_attr.remove_torch_payload()
self.shard_strategy.shard(tensor_list, self.process_group)
# 4. move sharded param grad payload to param.grad
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
continue
# Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
@@ -247,14 +245,10 @@ class ShardedModelV2(nn.Module):
# We also allow interleaving no-sync passes with sync passes, if desired.
if not self._require_backward_grad_sync:
continue
# Write the grad payload kept by the sharded param back to p.grad,
# and set p.col_attr.grad to None.
# As the sharded optimizer only updates a shard of the param,
# no matter whether we shard the param in the sharded model,
# we have to make sure the grad is a flat tensor shard.
# If world size == 1 and the param is sharded,
# the shape of `grad` is the same as the unsharded param,
# so we can just use `view(-1)` to ensure the grad is a flat tensor shard.
# Reduced grad is saved in `p.col_attr.saved_grad`
# It can be on CPU or CUDA
# It can be fp16 or fp32
# We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
if self.reuse_fp16_shard:
grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
else:
@@ -262,13 +256,15 @@ class ShardedModelV2(nn.Module):
assert isinstance(grad_fp16_payload, torch.Tensor)
if p.col_attr.offload_grad:
colo_model_data_move_to_cpu(grad_fp16_payload)
if not p.col_attr.fp32_grad.is_null():
if not p.col_attr.saved_grad.is_null():
assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
grad_fp16_payload = p.col_attr.fp32_grad.payload
p.col_attr.fp32_grad.set_null()
# Accumulate grad, saved grad must be fp32
p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
else:
p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
p.grad.data = grad_fp16_payload
p.grad = None
p.col_attr.fp16_grad.set_null()
@torch.no_grad()
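
Not part of the diff: a minimal sketch of the reduction-side hand-off these hunks introduce, pulled out as a free-standing function. The reduced shard is either stored into saved_grad or accumulated into it in fp32, and p.grad is left as None for the optimizer to rebuild later. The function name and signature are illustrative; the real logic lives inside ShardedModelV2.

import torch
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32

def _save_reduced_grad(p: torch.nn.Parameter, grad_payload: torch.Tensor, reuse_fp16_shard: bool) -> None:
    if not p.col_attr.saved_grad.is_null():
        # A grad from an earlier no-sync pass is already stored: accumulate in fp32.
        assert not reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
        p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
        p.col_attr.saved_grad.payload.add_(grad_payload.view_as(p.col_attr.saved_grad.payload))
    else:
        # First grad of the accumulation window: keep the reduced payload as-is
        # (it may be fp16 or fp32, on CPU or CUDA).
        p.col_attr.saved_grad.reset_payload(grad_payload)
    # p.grad stays None; ShardedOptimizerV2._prepare_grads() rebuilds it from saved_grad before step().
    p.grad = None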


@@ -5,23 +5,22 @@ from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
class OptimState(Enum):
@@ -170,6 +169,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
return cuda_use, cpu_use
def step(self, *args, **kwargs):
self._prepare_grads()
self._maybe_move_fp32_shards()
# unscale grads if scaled
@@ -294,3 +294,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.grad.data = p.grad.data.to(torch.cuda.current_device())
p.col_attr.offload_grad = False
fp32_shards_used_cuda_margin_mem += shard_mem
def _prepare_grads(self):
for group in self.optim.param_groups:
for p in group['params']:
# FIXME(ver217): p.data here is an empty tensor on CUDA and carries no useful information.
# If we assigned to p.grad directly, it could raise an error because the
# shape/dtype/device of p.data and p.grad would differ.
# As a workaround, we point p.data at p.col_attr.saved_grad.payload here as well.
p.data = p.col_attr.saved_grad.payload
p.grad = p.col_attr.saved_grad.payload
p.col_attr.saved_grad.set_null()
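
Not part of the diff: a sketch of the call order inside ShardedOptimizerV2.step() after this change, with _prepare_grads() (shown above) performing the hijack. Everything the wrapped optimizer does with p.data afterwards is outside this diff, so only the touched part is summarized.

def _step_order_sketch(sharded_optim, *args, **kwargs):
    # 1. Hijack p.data / p.grad with the saved flat grad shards (see _prepare_grads above).
    sharded_optim._prepare_grads()
    # 2. Optionally move fp32 master shards to CUDA within the memory margin.
    sharded_optim._maybe_move_fp32_shards()
    # 3. Unscale grads if scaled, then run the wrapped optimizer.
    return sharded_optim.optim.step(*args, **kwargs)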


@@ -11,7 +11,7 @@ class ShardedParamV2(object):
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel
self.offload_grad: bool = False
@@ -24,11 +24,6 @@ class ShardedParamV2(object):
if rm_torch_payload:
self.remove_torch_payload()
# Backward count for handling local grad accumulation
# This value is incremented by 1 in every pre-bwd hook
# and reset to 0 in every final-bwd hook
self.bwd_count = 0
def remove_torch_payload(self):
self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
@@ -66,9 +61,9 @@ class ShardedParamV2(object):
_update_mem_use(self.fp16_grad.payload)
address_set.add(self.fp16_grad.data_ptr())
if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
_update_mem_use(self.fp32_grad.payload)
address_set.add(self.fp32_grad.data_ptr())
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
_update_mem_use(self.saved_grad.payload)
address_set.add(self.saved_grad.data_ptr())
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
_update_mem_use(self.param.data)
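
Not part of the diff: a compact sketch of the grad part of get_memory_usage() after the rename, assuming the surrounding _update_mem_use helper counts numel() * element_size() bytes (which matches the byte counts asserted in the tests below).

def _grad_mem_bytes(col_attr) -> int:
    # Count fp16_grad and saved_grad once each, de-duplicating by storage pointer as above.
    seen, total = set(), 0
    for t in (col_attr.fp16_grad, col_attr.saved_grad):
        if not t.is_null() and t.data_ptr() not in seen:
            total += t.payload.numel() * t.payload.element_size()
            seen.add(t.data_ptr())
    return total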


@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
# zero_grad = zero_p.grad.clone().to(p.device)
zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
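
Not part of the diff: a hedged sketch of the whole check after this change; the comparison tail is assumed from the unchanged remainder of check_grads_padding (not shown in this hunk), and the tolerances are placeholders.

import torch
import torch.distributed as dist

def check_grads_padding_sketch(model, zero_model, loose: bool = False) -> None:
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        # p.grad is None on the zero model now, so read the reduced shard from col_attr.saved_grad.
        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
        grad = chunks[rank].to(zero_grad.dtype)
        # Assumed tail: the last chunk may be shorter than the padded shard, so compare the overlap.
        assert torch.allclose(grad, zero_grad.flatten()[:grad.numel()],
                              rtol=1e-3 if loose else 1e-5, atol=1e-3 if loose else 1e-5)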


@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
# Test get memory usage
sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = StatefulTensor(None)
sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert cuda_mem_use == 0
# reuse torch grad for sparam
sparam.fp32_grad = StatefulTensor(param.grad)
sparam.saved_grad = StatefulTensor(param.grad)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 0
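
Not part of the diff: the arithmetic behind the byte counts asserted above; reading the trailing + 4 as the 0-dim placeholder that remove_torch_payload() leaves in param.data is an assumption, not something this hunk states.

bytes_2x3_fp32 = 2 * 3 * 4                 # a 2 x 3 fp32 tensor occupies 24 bytes
bytes_2x3_fp16 = 2 * 3 * 2                 # a 2 x 3 fp16 tensor occupies 12 bytes
assert bytes_2x3_fp16 == 12                # the `cuda_mem_use == 2 * 3 * 2` assert
assert bytes_2x3_fp32 * 2 == 48            # two fp32 tensors -> the `2 * 3 * 4 * 2` asserts
assert bytes_2x3_fp32 * 2 + 4 == 52        # plus a 4-byte 0-dim fp32 placeholder -> `2 * 3 * 4 * 2 + 4`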