mirror of https://github.com/hpcaitech/ColossalAI
[zero] hijack p.grad in sharded model (#554)
* hijack p.grad in sharded model
* polish comments
* polish comments

pull/558/head
parent f552b11294
commit 014bac0c49
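In short, this commit stops ShardedModelV2 from writing the reduced gradient back into p.grad at the end of backward: the gradient is parked in p.col_attr.saved_grad (a StatefulTensor), p.grad is left as None, and ShardedOptimizerV2 re-points p.grad (and p.data) at that saved payload in a new _prepare_grads() call right before step(). Below is a minimal, self-contained sketch of that hand-off pattern in plain PyTorch; SavedGrad and the _saved_grad attribute are illustrative stand-ins, not ColossalAI's API.

import torch

class SavedGrad:
    """Illustrative stand-in for the StatefulTensor-style holder used in this PR."""

    def __init__(self):
        self.payload = None

    def is_null(self):
        return self.payload is None

    def reset_payload(self, tensor):
        self.payload = tensor

    def set_null(self):
        self.payload = None

def end_of_backward(p, reduced_grad):
    # Keep the reduced (possibly fp16) grad outside of p.grad.
    holder = p._saved_grad
    if not holder.is_null():
        # A grad from a previous pass is already stored: accumulate in fp32.
        holder.reset_payload(holder.payload.float())
        holder.payload.add_(reduced_grad.view_as(holder.payload))
    else:
        holder.reset_payload(reduced_grad)
    p.grad = None  # the "hijack": autograd's grad slot stays empty after backward

def prepare_grads_before_step(p):
    # Right before optimizer.step(), hand the saved payload back to p.grad.
    # p.data is re-pointed first so the shape/dtype/device check on p.grad passes.
    p.data = p._saved_grad.payload
    p.grad = p._saved_grad.payload
    p._saved_grad.set_null()

# Usage: two accumulation passes, then grads are made visible to the optimizer.
p = torch.nn.Parameter(torch.zeros(4))
p._saved_grad = SavedGrad()
end_of_backward(p, torch.full((4,), 0.5, dtype=torch.float16))
end_of_backward(p, torch.full((4,), 0.5, dtype=torch.float16))
assert p.grad is None
prepare_grads_before_step(p)
assert torch.allclose(p.grad, torch.ones(4))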
@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
from colossalai.utils.memory_utils.utils import \
    colo_model_data_tensor_move_inline


@OPHOOKS.register_module
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
            param.data = param.col_attr.sharded_data_tensor.payload
            # Store local accumulated grad shard
            if param.grad is not None:
                if param.col_attr.bwd_count == 0:
                    # We haven't stored local accumulated grad yet
                    assert param.col_attr.fp32_grad.is_null()

                    # Allocate grad fp32 memory space here
                    param.col_attr.fp32_grad.reset_payload(param.grad.data)
                    # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
                    param.grad = None
                else:
                    # We have stored local accumulated grad
                    # The grad here must be locally computed full grad in this backward pass
                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
            param.col_attr.bwd_count += 1
        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()
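For context, the hunk above is ZeroHook's pre-backward path: each parameter's sharded payload is moved to the computing device and re-attached as param.data before the op runs, and the fp32_grad / bwd_count bookkeeping shown here is what this commit removes from the hook (gradient handling now lives entirely in ShardedModelV2, below). As a loose analogy using plain PyTorch module hooks rather than ColossalAI's OpHook machinery, re-materializing placeholder parameters on the computing device looks roughly like this; computing_device and _full_payload are made-up names for illustration.

import torch
import torch.nn as nn

computing_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def materialize_params(module, inputs):
    # Before the module runs, swap each parameter's placeholder data for the
    # real payload on the computing device, mirroring the per-op move above.
    for p in module.parameters(recurse=False):
        full = getattr(p, '_full_payload', None)
        if full is not None:
            p.data = full.to(computing_device)

layer = nn.Linear(4, 4)
for p in layer.parameters():
    p._full_payload = p.data.clone()          # pretend this is the gathered tensor
    p.data = torch.empty([], dtype=p.dtype)   # placeholder, like remove_torch_payload()
layer.register_forward_pre_hook(materialize_params)
out = layer(torch.randn(2, 4, device=computing_device))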
@@ -12,20 +12,18 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                     get_gradient_predivide_factor)
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState


class ShardedModelV2(nn.Module):
@@ -233,11 +231,11 @@ class ShardedModelV2(nn.Module):
            if not p.col_attr.param_is_sharded:
                tensor_list.append(p.col_attr.sharded_data_tensor)
                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
                p.col_attr.remove_torch_payload()
        self.shard_strategy.shard(tensor_list, self.process_group)

        # 4. move sharded param grad payload to param.grad
        for p in self.module.parameters():
            p.col_attr.bwd_count = 0
            if not p.requires_grad:
                continue
            # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
@@ -247,14 +245,10 @@ class ShardedModelV2(nn.Module):
            # We also allow interleaving no-sync passes with sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue
            # Write grad payload kept by sharded param back to p.grad,
            # and set p.col_attr.grad to None
            # As sharded optimizer only updates a shard of param,
            # no matter whether we shard param in sharded model
            # We have to make sure the grad is a flat tensor shard
            # If world size == 1 and param is sharded,
            # the shape of `grad` is the same as unsharded param
            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
            # Reduced grad is saved in `p.col_attr.saved_grad`
            # It can be on CPU or CUDA
            # It can be fp16 or fp32
            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
            if self.reuse_fp16_shard:
                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
            else:
@@ -262,13 +256,15 @@ class ShardedModelV2(nn.Module):
            assert isinstance(grad_fp16_payload, torch.Tensor)
            if p.col_attr.offload_grad:
                colo_model_data_move_to_cpu(grad_fp16_payload)
            if not p.col_attr.fp32_grad.is_null():
            if not p.col_attr.saved_grad.is_null():
                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
                grad_fp16_payload = p.col_attr.fp32_grad.payload
                p.col_attr.fp32_grad.set_null()
                # Accumulate grad, saved grad must be fp32
                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
            else:
                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)

            p.grad.data = grad_fp16_payload
            p.grad = None
            p.col_attr.fp16_grad.set_null()

    @torch.no_grad()
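The block above is where the "hijack" lands: whatever ends up in p.col_attr.saved_grad is a flat shard of the gradient, optionally moved to CPU when offload_grad is set, accumulated in fp32 when a previous pass already stored something, and p.grad is set to None instead of being repopulated. A rough, self-contained sketch of that invariant; the rank/world_size handling and the saved dict are illustrative, while the real code goes through reduce-scatter and StatefulTensor.

import torch

def local_flat_shard(full_grad, rank, world_size):
    # Flatten, pad to a multiple of world_size and keep only this rank's chunk,
    # so the saved grad always has the flat-shard shape the optimizer expects.
    flat = full_grad.reshape(-1)
    pad = (-flat.numel()) % world_size
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    return flat.chunk(world_size)[rank].clone()

def save_reduced_grad(saved, key, reduced, offload_to_cpu=False):
    if offload_to_cpu:
        reduced = reduced.cpu()
    if key in saved:
        # Gradient accumulation: keep the running sum in fp32.
        saved[key] = saved[key].float()
        saved[key].add_(reduced.view_as(saved[key]))
    else:
        saved[key] = reduced          # first pass: store as-is (fp16 or fp32)

# Usage: two passes on a fake gradient; the second accumulates in fp32.
saved = {}
g = torch.randn(3, 5, dtype=torch.float16)
shard = local_flat_shard(g, rank=0, world_size=2)
save_reduced_grad(saved, 'w', shard, offload_to_cpu=True)
save_reduced_grad(saved, 'w', shard, offload_to_cpu=True)
assert saved['w'].dtype == torch.float32 and saved['w'].numel() == 8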
@@ -5,23 +5,22 @@ from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
                                                 colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer


class OptimState(Enum):
@@ -170,6 +169,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        return cuda_use, cpu_use

    def step(self, *args, **kwargs):
        self._prepare_grads()
        self._maybe_move_fp32_shards()

        # unscale grads if scaled
@@ -294,3 +294,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
                        p.col_attr.offload_grad = False
                        fp32_shards_used_cuda_margin_mem += shard_mem

    def _prepare_grads(self):
        for group in self.optim.param_groups:
            for p in group['params']:
                # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information
                # If we change p.grad directly
                # it may raise error because of different shape/dtype/device of p.data and p.grad
                # We just set p.data = p.col_attr.saved_grad.payload here
                p.data = p.col_attr.saved_grad.payload
                p.grad = p.col_attr.saved_grad.payload
                p.col_attr.saved_grad.set_null()
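The FIXME above explains why _prepare_grads() re-points p.data before touching p.grad: PyTorch validates a direct .grad assignment against the parameter's current shape, dtype and device, and at this point p.data is only the empty placeholder left by remove_torch_payload(). A small demonstration of that check with plain tensors (the exact error text varies by PyTorch version):

import torch

p = torch.nn.Parameter(torch.empty([]))            # placeholder, like remove_torch_payload()
saved = torch.ones(8)                               # stands in for saved_grad.payload
try:
    p.grad = saved                                  # shape mismatch against p.data
except RuntimeError as e:
    print('direct .grad assignment rejected:', e)
p.data = saved                                      # re-point data first, as _prepare_grads does
p.grad = saved                                      # now the check passes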
@@ -11,7 +11,7 @@ class ShardedParamV2(object):
    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
        self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
        # This attribute must be initialized in ShardedModel
        self.offload_grad: bool = False
@@ -24,11 +24,6 @@ class ShardedParamV2(object):
        if rm_torch_payload:
            self.remove_torch_payload()

        # Backward count for handling local grad accumulation
        # This value will increment by 1 in every pre-bwd hook
        # And will be reset to 0 in every final-bwd hook
        self.bwd_count = 0

    def remove_torch_payload(self):
        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
@@ -66,9 +61,9 @@ class ShardedParamV2(object):
            _update_mem_use(self.fp16_grad.payload)
            address_set.add(self.fp16_grad.data_ptr())

        if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
            _update_mem_use(self.fp32_grad.payload)
            address_set.add(self.fp32_grad.data_ptr())
        if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
            _update_mem_use(self.saved_grad.payload)
            address_set.add(self.saved_grad.data_ptr())

        if self.param.data is not None and self.param.data.data_ptr() not in address_set:
            _update_mem_use(self.param.data)
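get_memory_usage (partially shown above) deduplicates by data_ptr() so that tensors sharing storage, for example a saved grad that reuses the fp16 shard's payload, are counted once. A stripped-down version of the same accounting idea, with no ColossalAI types:

import torch

def mem_usage(tensors):
    # Sum CUDA and CPU bytes over tensors, counting each storage address once.
    cuda_use, cpu_use, seen = 0, 0, set()
    for t in tensors:
        if t is None or t.data_ptr() in seen:
            continue
        seen.add(t.data_ptr())
        nbytes = t.numel() * t.element_size()
        if t.device.type == 'cuda':
            cuda_use += nbytes
        else:
            cpu_use += nbytes
    return cuda_use, cpu_use

# Usage: the aliased tensor is not double counted.
w = torch.randn(2, 3)                       # 6 * 4 = 24 bytes, fp32
g = torch.randn(2, 3, dtype=torch.float16)  # 6 * 2 = 12 bytes, fp16
print(mem_usage([w, w, g]))                 # (0, 36) on a CPU-only run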
@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        # zero_grad = zero_p.grad.clone().to(p.device)
        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
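The test change above compares against zero_p.col_attr.saved_grad.payload instead of zero_p.grad, since p.grad is now None after backward in the sharded model. The surrounding comparison logic, roughly: flatten the reference gradient, split it into world_size chunks, and check this rank's chunk against the saved shard (padding aside). A self-contained illustration with fake tensors; the function name and tolerance are mine:

import torch

def check_shard_against_reference(full_ref_grad, saved_shard, rank, world_size, atol=1e-3):
    chunks = torch.flatten(full_ref_grad).chunk(world_size)
    if rank >= len(chunks):
        return True                      # this rank only holds padding
    ref = chunks[rank]
    return torch.allclose(ref, saved_shard[:ref.numel()].to(ref.dtype), atol=atol)

# Usage with a fake full gradient and its rank-1 shard
full = torch.randn(10)
world_size, rank = 4, 1
shard = torch.flatten(full).chunk(world_size)[rank].clone()
assert check_shard_against_reference(full, shard, rank, world_size)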
@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
    allclose(sparam.sharded_data_tensor.payload, param_ref.data)

    # Test get memory usage
    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
    assert cuda_mem_use == 2 * 3 * 2

    sparam.fp16_grad = StatefulTensor(None)
    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
    sparam.remove_torch_payload()
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
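A note on the expected numbers in these assertions (my reading of the test, not part of the diff): each payload here is a 2x3 fp32 tensor, i.e. 2 * 3 elements * 4 bytes = 24 bytes, so two CPU-resident payloads give 2 * 3 * 4 * 2 = 48 bytes, and the CUDA figure 2 * 3 * 2 = 12 bytes corresponds to one 2x3 fp16 tensor. The extra + 4 after remove_torch_payload() presumably accounts for the 0-dim fp32 placeholder (one element, 4 bytes) that torch.empty([]) leaves behind as param.data.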
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
    assert cuda_mem_use == 0

    # reuse torch grad for sparam
    sparam.fp32_grad = StatefulTensor(param.grad)
    sparam.saved_grad = StatefulTensor(param.grad)
    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
    assert cpu_mem_use == 2 * 3 * 4 * 2
    assert cuda_mem_use == 0