[zero] fix grad offload (#528)

* [zero] fix grad offload

* polish code
Jiarui Fang 2022-03-25 18:23:25 +08:00 committed by GitHub
parent 105c5301c3
commit 05e33b2578
3 changed files with 28 additions and 7 deletions


@@ -114,3 +114,20 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
+
+
+def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
+    """
+    Clone a model data tensor
+
+    Args:
+        t (Union[ShardedTensor, torch.Tensor]): a model data tensor
+        target_device (torch.device): the target device
+
+    Returns:
+        torch.Tensor: a cloned torch tensor
+    """
+    t_payload = t.payload if isinstance(t, ShardedTensor) else t
+    ret = t_payload.to(target_device)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
+    return ret
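
For context, a minimal plain-PyTorch sketch (names here are illustrative, not part of the patch) of what the new helper does and how it differs from the in-place move in colo_model_data_move_to_cpu: it returns a copy on the target device and leaves the source payload untouched. Note that torch.Tensor.to() returns the tensor itself when the device already matches, so a real copy only happens across devices.

import torch

def tensor_clone_sketch(t: torch.Tensor, target_device: torch.device) -> torch.Tensor:
    # Copy the payload onto target_device and return the copy; the source
    # tensor keeps its own device and storage, unlike the in-place
    # `t.data = t.data.cpu()` move performed by colo_model_data_move_to_cpu.
    return t.to(target_device)

if torch.cuda.is_available():
    grad_on_gpu = torch.randn(8, device='cuda')   # stands in for a grad payload
    grad_on_cpu = tensor_clone_sketch(grad_on_gpu, torch.device('cpu'))
    assert grad_on_gpu.is_cuda and not grad_on_cpu.is_cuda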


@@ -11,7 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity
+from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -198,16 +198,16 @@ class ShardedModelV2(nn.Module):
             # the shape of `grad` is the same as the unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad = p.col_attr.sharded_data_tensor.payload
+                grad_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad)
+                grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
             if p.col_attr.fp32_grad is not None:
                 assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
-                grad = p.col_attr.fp32_grad
-            p.grad.data = grad
+                p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
+                grad_payload = p.col_attr.fp32_grad
+            p.grad.data = grad_payload
             p.col_attr.fp16_grad = None
             p.col_attr.fp32_grad = None
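
One way to read this hunk (an interpretation, not wording from the patch): with reuse_fp16_shard enabled the grad payload is the fp16 param shard itself, so the old in-place colo_model_data_move_to_cpu call relocated that shared storage to the CPU, while the new colo_model_tensor_clone call only offloads a copy of the gradient and leaves the device tensor intact. A small plain-PyTorch illustration of that aliasing effect (variable names are hypothetical):

import torch

if torch.cuda.is_available():
    # With reuse_fp16_shard, the grad is literally the fp16 param shard,
    # so both names below refer to one tensor object.
    param_shard = torch.randn(8, device='cuda')
    grad = param_shard

    # Cloning (new behaviour): copy to CPU, the shard stays on the GPU.
    grad_cpu = grad.to('cpu')
    assert param_shard.is_cuda and not grad_cpu.is_cuda

    # In-place move (old behaviour): rebinding .data drags the shared
    # storage to CPU, so the reused param shard moves along with it.
    grad.data = grad.data.cpu()
    assert not param_shard.is_cuda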


@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
@@ -217,6 +218,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None
         # Because we will judge whether local grad accumulation
         # is enabled by whether grad is None
+        for group in self.param_groups:
+            for p in group['params']:
+                GLOBAL_MODEL_DATA_TRACER.delete_tensor(p.grad)
         self.optim.zero_grad(set_to_none=True)

     def sync_grad(self):
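
The comment above leans on a PyTorch detail worth spelling out: zero_grad(set_to_none=True) drops the .grad tensors entirely rather than zeroing them, so a later check of p.grad is None can tell a fresh step apart from locally accumulated gradients; presumably that is also why the grads are removed from GLOBAL_MODEL_DATA_TRACER first, before the references go away. A minimal sketch of the None-based check (plain PyTorch, illustrative names):

import torch

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([p], lr=0.1)

p.sum().backward()
assert p.grad is not None           # a gradient exists after backward

opt.zero_grad(set_to_none=True)     # drop the grad tensor instead of zeroing it
assert p.grad is None               # "no local accumulation yet" is now just a None check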