mirror of https://github.com/hpcaitech/ColossalAI

use double buffer to handle grad

parent 0f5f5dd556
commit 9506a8beb2
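This commit replaces the single per-parameter gradient buffer (col_attr.grad) with two buffers: fp16_grad, which holds the freshly reduced gradient of the current backward pass, and fp32_grad, which holds the locally accumulated fp32 gradient across micro-batches. As a rough illustration of that split (a minimal sketch in plain PyTorch; the DoubleGradBuffer name and its methods are invented here and are not part of ColossalAI), the pattern looks like this:

    # Minimal sketch of the double-buffer idea; names are illustrative only.
    import torch


    class DoubleGradBuffer:
        """Stand-in for the fp16_grad / fp32_grad pair kept on a param's col_attr."""

        def __init__(self):
            self.fp16_grad = None    # freshly reduced grad of the current backward
            self.fp32_grad = None    # locally accumulated fp32 grad across micro-batches

        def store_reduced(self, reduced_grad: torch.Tensor) -> None:
            # Reduce hook: just park the reduced low-precision grad;
            # no cast, no offload, no accumulation at this point.
            self.fp16_grad = reduced_grad

        def merge(self, cpu_offload: bool = False) -> torch.Tensor:
            # After backward: cast to fp32, optionally move to CPU, fold into the
            # accumulation buffer if one exists, and hand back a flat fp32 grad.
            grad = self.fp16_grad.float()
            if cpu_offload:
                grad = grad.cpu()
            if self.fp32_grad is not None:
                self.fp32_grad.add_(grad.view_as(self.fp32_grad))
                grad = self.fp32_grad
            self.fp16_grad = None
            self.fp32_grad = None
            return grad.view(-1)

The hunks below wire this split into ZeroHook, ZeroInitContext, ShardedModelV2 and ShardedParamV2.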
@@ -1,12 +1,14 @@
+from typing import Optional
+
 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from typing import Optional
 
 
 @OPHOOKS.register_module
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
                     # We haven't stored local accumulated grad yet
-                    assert param.col_attr.grad is None
-                    param.col_attr.grad = param.grad.data
+                    assert param.col_attr.fp32_grad is None
+                    param.col_attr.fp32_grad = param.grad.data
                     param.grad = None
                 else:
                     # We have stored local accumulated grad
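In terms of the sketch above, this hook change stashes any leftover accumulated gradient into the fp32 buffer before a fresh backward starts, so autograd writes the new gradient into an empty param.grad. A hypothetical equivalent with the sketch class:

    # Illustrative only: stash what earlier micro-batches accumulated, then
    # free param.grad so the next backward writes into a fresh tensor.
    import torch

    param = torch.nn.Parameter(torch.ones(4))
    param.grad = torch.full((4,), 0.5)       # grad accumulated so far
    buf = DoubleGradBuffer()                 # sketch class from above
    buf.fp32_grad = param.grad.data          # what the hook's fp32_grad assignment does
    param.grad = None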

@@ -1,9 +1,10 @@
 import functools
 
 import torch
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 # Inserts _post_init_method at the end of init method
 
@@ -154,6 +155,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
             GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
-            if param.col_attr.grad and self.shard_grad:
-                self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
+            # if param.col_attr.grad and self.shard_grad:
+            #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)

@@ -1,5 +1,5 @@
-from ast import Try
 import functools
+from ast import Try
 from collections import OrderedDict
 from typing import Any, Optional
 
@@ -12,16 +12,17 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
+from colossalai.utils.commons.memory import col_cuda_memory_capacity
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
+
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
-from colossalai.utils.commons.memory import col_cuda_memory_capacity
 
 
 class ShardedModelV2(nn.Module):
@@ -164,8 +165,15 @@ class ShardedModelV2(nn.Module):
                 # If world size == 1 and sharded param,
                 # the shape `grad` is the same as unsharded param
                 # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-                p.grad.data = p.col_attr.grad.view(-1)
-                p.col_attr.grad = None
+                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                if self._cpu_offload:
+                    col_move_to_cpu(grad)
+                if p.col_attr.fp32_grad is not None:
+                    p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
+                    grad = p.col_attr.fp32_grad
+                p.grad.data = grad.view(-1)
+                p.col_attr.fp16_grad = None
+                p.col_attr.fp32_grad = None
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -216,23 +224,7 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
 
-        # Make sure we store fp32 grad
-        reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
-
-        # Maybe offload
-        # TODO() optimize GPU->CPU bandwidth utilization
-        if self._cpu_offload:
-            col_move_to_cpu(reduced_grad)
-            # reduced_grad.data = reduced_grad.data.cpu()
-
-        if param.col_attr.grad is None:
-            param.col_attr.grad = reduced_grad.data
-        else:
-            # When dp size = 1
-            # param.col_attr.grad is local accumulated grad shard (full but flatten)
-            # But reduced_grad here is full grad
-            # We should call `view_as`
-            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
+        param.col_attr.fp16_grad = reduced_grad.data
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
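Taken together, the two hunks above split the old single-pass handling into the two phases of the sketch near the top: the post-backward reduce hook now only stores the reduced fp16 gradient, and the per-parameter pass after backward does the cast, optional CPU offload, accumulation and flattening. A hypothetical end-to-end run of the sketch over two accumulated micro-batches:

    # Illustrative only; uses the DoubleGradBuffer sketch from the top.
    import torch

    buf = DoubleGradBuffer()

    # micro-batch 1: hook stores the reduced grad, merge yields a flat fp32 grad
    buf.store_reduced(torch.full((2, 2), 2.0, dtype=torch.float16))
    flat = buf.merge()                       # tensor([2., 2., 2., 2.])

    # micro-batch 2: the stashed grad is accumulated into the new one
    buf.fp32_grad = flat
    buf.store_reduced(torch.ones(2, 2, dtype=torch.float16))
    flat = buf.merge()                       # tensor([3., 3., 3., 3.])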

@@ -16,7 +16,8 @@ class ShardedParamV2(object):
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
         self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
-        self._grad_sharded_tensor: Optional[torch.Tensor] = None
+        self.fp16_grad: Optional[torch.Tensor] = None
+        self.fp32_grad: Optional[torch.Tensor] = None
 
         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.
@@ -39,14 +40,6 @@ class ShardedParamV2(object):
     def data(self):
         return self._data_sharded_tensor
 
-    @property
-    def grad(self):
-        return self._grad_sharded_tensor
-
-    @grad.setter
-    def grad(self, t: torch.Tensor):
-        self._grad_sharded_tensor = t
-
     @property
     def param_is_sharded(self):
         return self._data_sharded_tensor.is_sharded