mirror of https://github.com/hpcaitech/ColossalAI

[zero] hijack p.grad in sharded model (#554)

* hijack p.grad in sharded model
* polish comments
* polish comments

pull/558/head
parent f552b11294
commit 014bac0c49
@@ -9,7 +9,9 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState

 from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
+from colossalai.utils.memory_utils.utils import \
+    colo_model_data_tensor_move_inline


 @OPHOOKS.register_module
@@ -67,21 +69,6 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
-            # Store local accumulated grad shard
-            if param.grad is not None:
-                if param.col_attr.bwd_count == 0:
-                    # We haven't stored local accumulated grad yet
-                    assert param.col_attr.fp32_grad.is_null()
-
-                    # Allocate grad fp32 memory space here
-                    param.col_attr.fp32_grad.reset_payload(param.grad.data)
-                    # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
-                    param.grad = None
-                else:
-                    # We have stored local accumulated grad
-                    # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
-            param.col_attr.bwd_count += 1
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

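The hunk above deletes the per-hook gradient bookkeeping (the bwd_count counter plus an fp32_grad stash) from the pre-backward hook; the rest of the commit moves gradient handling into ShardedModelV2's post-backward path and a new ShardedOptimizerV2._prepare_grads(). For contrast with the new flow shown further down, here is a toy re-expression of the deleted bookkeeping; ToyStatefulTensor and the col_attr dict are stand-ins invented for illustration, not ColossalAI classes.

import torch


class ToyStatefulTensor:
    """Minimal stand-in for ColossalAI's StatefulTensor (illustration only)."""

    def __init__(self, tensor=None):
        self._payload = tensor

    def is_null(self):
        return self._payload is None

    def reset_payload(self, tensor):
        self._payload = tensor

    @property
    def payload(self):
        return self._payload


def stash_grad_like_removed_hook(param, col_attr):
    # First backward pass: park p.grad in the fp32 stash and clear p.grad.
    # Later passes: leave the freshly computed full grad on p.grad so it can
    # be accumulated elsewhere. `col_attr` is a plain dict here, not the real
    # ShardedParamV2 attribute object.
    if param.grad is not None:
        if col_attr["bwd_count"] == 0:
            assert col_attr["fp32_grad"].is_null()
            col_attr["fp32_grad"].reset_payload(param.grad.data)
            param.grad = None
        else:
            # must be the full, locally computed gradient of this pass
            assert param.grad.shape == param.shape
    col_attr["bwd_count"] += 1


p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.ones(4)
col_attr = {"bwd_count": 0, "fp32_grad": ToyStatefulTensor()}
stash_grad_like_removed_hook(p, col_attr)
assert p.grad is None and not col_attr["fp32_grad"].is_null()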
@@ -12,20 +12,18 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState


 class ShardedModelV2(nn.Module):
@@ -233,11 +231,11 @@ class ShardedModelV2(nn.Module):
             if not p.col_attr.param_is_sharded:
                 tensor_list.append(p.col_attr.sharded_data_tensor)
                 p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                p.col_attr.remove_torch_payload()
         self.shard_strategy.shard(tensor_list, self.process_group)

         # 4. move sharded param grad payload to param.grad
         for p in self.module.parameters():
-            p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
             # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
@@ -247,14 +245,10 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad payload kept by sharded param back to p.grad,
-            # and set p.col_attr.grad to None
-            # As sharded optimizer only update a shard of param,
-            # no matter whether we shard param in sharded model
-            # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and param is sharded,
-            # the shape `grad` is the same as unsharded param
-            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
+            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # It can be on CPU or CUDA
+            # It can be fp16 or fp32
+            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
                 grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
@@ -262,13 +256,15 @@ class ShardedModelV2(nn.Module):
             assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.fp32_grad.is_null():
+            if not p.col_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_fp16_payload = p.col_attr.fp32_grad.payload
-                p.col_attr.fp32_grad.set_null()
+                # Accumulate grad, saved grad must be fp32
+                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
+                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+            else:
+                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)

-            p.grad.data = grad_fp16_payload
+            p.grad = None
             p.col_attr.fp16_grad.set_null()

     @torch.no_grad()
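After this hunk the reduced gradient is no longer written back to p.grad; it is parked in p.col_attr.saved_grad (cast to fp32 when accumulating) and p.grad is set to None until the optimizer asks for it. A minimal sketch of that save/accumulate path, assuming a simplified ToySavedGrad in place of the StatefulTensor that backs saved_grad and .float() in place of cast_tensor_to_fp32:

import torch


class ToySavedGrad:
    """Minimal stand-in for the saved_grad StatefulTensor (illustration only)."""

    def __init__(self):
        self._payload = None

    def is_null(self):
        return self._payload is None

    def reset_payload(self, tensor):
        self._payload = tensor

    @property
    def payload(self):
        return self._payload


def save_reduced_grad(saved_grad: ToySavedGrad, grad_fp16_payload: torch.Tensor):
    # Mirror the new post-backward logic: the reduced grad is parked in
    # saved_grad instead of p.grad. On accumulation the stored grad is cast
    # to fp32 first, so repeated fp16 additions do not lose precision.
    if not saved_grad.is_null():
        saved_grad.reset_payload(saved_grad.payload.float())
        saved_grad.payload.add_(grad_fp16_payload.view_as(saved_grad.payload))
    else:
        saved_grad.reset_payload(grad_fp16_payload)


saved = ToySavedGrad()
save_reduced_grad(saved, torch.ones(4, dtype=torch.float16))   # first backward
save_reduced_grad(saved, torch.ones(4, dtype=torch.float16))   # accumulation
assert saved.payload.dtype == torch.float32 and bool((saved.payload == 2).all())
# After this, the model sets p.grad = None; ShardedOptimizerV2 restores it in step().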
@@ -5,23 +5,22 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
-
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+                                                 colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from torch.optim import Optimizer


 class OptimState(Enum):
@@ -170,6 +169,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         return cuda_use, cpu_use

     def step(self, *args, **kwargs):
+        self._prepare_grads()
         self._maybe_move_fp32_shards()

         # unscale grads if scaled
@@ -294,3 +294,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
                        p.col_attr.offload_grad = False
                        fp32_shards_used_cuda_margin_mem += shard_mem
+
+    def _prepare_grads(self):
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
+                # If we change p.grad directly
+                # it may raise error because of different shape/dtype/device of p.data and p.grad
+                # We just set p.data = p.col_attr.saved_grad.payload here
+                p.data = p.col_attr.saved_grad.payload
+                p.grad = p.col_attr.saved_grad.payload
+                p.col_attr.saved_grad.set_null()
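step() now calls _prepare_grads() first, so the wrapped optimizer sees a usable p.grad again before it updates the shard. A hedged sketch of that hand-off with ad-hoc stand-ins (the col_attrs dict, ToyStatefulTensor and prepare_grads are illustrative, not ColossalAI API):

import torch


class ToyStatefulTensor:
    """Stand-in for the StatefulTensor behind col_attr.saved_grad (illustration only)."""

    def __init__(self, tensor=None):
        self.payload = tensor

    def set_null(self):
        self.payload = None


def prepare_grads(optim: torch.optim.Optimizer, col_attrs: dict):
    # Re-point p.grad (and p.data, which held an empty placeholder) at the
    # grad shard the sharded model parked in saved_grad, then release the
    # stateful tensor -- mirroring what _prepare_grads() does above.
    for group in optim.param_groups:
        for p in group["params"]:
            saved = col_attrs[p]
            p.data = saved.payload
            p.grad = saved.payload
            saved.set_null()


p = torch.nn.Parameter(torch.zeros(4))
optim = torch.optim.SGD([p], lr=0.1)
col_attrs = {p: ToyStatefulTensor(torch.ones(4))}

prepare_grads(optim, col_attrs)   # what ShardedOptimizerV2.step() now does first
optim.step()                      # the inner optimizer can run because p.grad is populated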
@@ -11,7 +11,7 @@ class ShardedParamV2(object):
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
-        self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
+        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False

@@ -24,11 +24,6 @@ class ShardedParamV2(object):
         if rm_torch_payload:
             self.remove_torch_payload()

-        # Backward count for handle local grad accumulation
-        # This value will increment by 1 in every pre-bwd hook
-        # And will be reset to 0 in every final-bwd hook
-        self.bwd_count = 0
-
     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

@@ -66,9 +61,9 @@ class ShardedParamV2(object):
            _update_mem_use(self.fp16_grad.payload)
            address_set.add(self.fp16_grad.data_ptr())

-        if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp32_grad.payload)
-            address_set.add(self.fp32_grad.data_ptr())
+        if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
+            _update_mem_use(self.saved_grad.payload)
+            address_set.add(self.saved_grad.data_ptr())

        if self.param.data is not None and self.param.data.data_ptr() not in address_set:
            _update_mem_use(self.param.data)
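fp32_grad is renamed to saved_grad here, and the data_ptr()-based checks above keep tensors that share storage (for example saved_grad reusing param.grad) from being counted twice. A small self-contained sketch of that accounting pattern; the function name and setup are illustrative, not the ShardedParamV2 API:

import torch


def memory_usage(tensors):
    # Count each distinct storage once, keyed by data_ptr(), and split the
    # totals by device, like get_memory_usage() does above.
    cuda_use, cpu_use, seen = 0, 0, set()
    for t in tensors:
        if t is None or t.data_ptr() in seen:
            continue
        seen.add(t.data_ptr())
        nbytes = t.numel() * t.element_size()
        if t.device.type == "cuda":
            cuda_use += nbytes
        else:
            cpu_use += nbytes
    return cuda_use, cpu_use


param = torch.randn(2, 3)
saved_grad = param            # shares storage, so it is counted only once
cuda_use, cpu_use = memory_usage([param, saved_grad])
assert cpu_use == 2 * 3 * 4 and cuda_use == 0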
@@ -92,7 +92,8 @@ def check_params(model, zero_model, loose=False):
 def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
+        # zero_grad = zero_p.grad.clone().to(p.device)
+        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
@@ -53,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)

     # Test get memory usage
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"

@@ -69,7 +69,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 2 * 3 * 2

     sparam.fp16_grad = StatefulTensor(None)
-    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
+    sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -83,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0

     # reuse torch grad for sparam
-    sparam.fp32_grad = StatefulTensor(param.grad)
+    sparam.saved_grad = StatefulTensor(param.grad)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
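check_grads_padding now reads the ZeRO gradient from zero_p.col_attr.saved_grad.payload instead of zero_p.grad, since p.grad is None after backward. The per-rank reference it compares against is built by flattening the dense grad and chunking it by world size; a standalone sketch of that reference construction (padding the chunk to the shard length is an assumption about the helper, not code from the test):

import torch
import torch.nn.functional as F


def local_reference_chunk(full_grad: torch.Tensor, rank: int, world_size: int,
                          shard_numel: int) -> torch.Tensor:
    # Flatten the dense reference grad, take this rank's chunk and pad it to
    # the shard length so it lines up with the (padded) ZeRO grad shard.
    chunks = torch.flatten(full_grad).chunk(world_size)
    if rank >= len(chunks):
        return torch.zeros(shard_numel)
    chunk = chunks[rank]
    return F.pad(chunk, (0, shard_numel - chunk.numel()))


full = torch.arange(10, dtype=torch.float32)
shard = local_reference_chunk(full, rank=1, world_size=4, shard_numel=3)
assert shard.tolist() == [3.0, 4.0, 5.0]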