[zero] label state for param fp16 and grad (#551)

Jiarui Fang 2022-03-30 15:57:46 +08:00 committed by GitHub
parent 92f4224867
commit f552b11294
7 changed files with 75 additions and 39 deletions


@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'col_attr')
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
if param.col_attr.bwd_count == 0:
# We haven't stored local accumulated grad yet
assert param.col_attr.fp32_grad.is_null()
# Allocate grad fp32 memory space here
param.col_attr.fp32_grad.reset_payload(param.grad.data)
# TODO(jiaruifang) we should set grad fp16 state to HOLD here.
param.grad = None
else:
# We have stored local accumulated grad
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
param.col_attr.remove_torch_payload()
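For context on the change above: ZeroHook now labels each parameter's sharded data tensor around every submodule op — COMPUTE before forward/backward, HOLD_AFTER_FWD / HOLD_AFTER_BWD afterwards — while ShardedModelV2 (next file) resets everything to HOLD around the whole forward. A minimal, self-contained sketch of that lifecycle, using a stand-in class rather than the real StatefulTensor:

```python
from enum import Enum, auto

class TensorState(Enum):
    # Stand-in; mirrors the state names used in this commit.
    FREE = auto()
    HOLD = auto()
    HOLD_AFTER_FWD = auto()
    HOLD_AFTER_BWD = auto()
    COMPUTE = auto()

class LabeledTensor:
    """Stand-in for StatefulTensor: a payload plus a lifecycle label."""
    def __init__(self, payload, state: TensorState = TensorState.HOLD) -> None:
        self.payload = payload
        self.state = state

    def trans_state(self, state: TensorState) -> None:
        self.state = state

# One training step, as driven by the hooks:
t = LabeledTensor(payload=object())
t.trans_state(TensorState.COMPUTE)           # pre_fwd_exec: about to be used in forward
t.trans_state(TensorState.HOLD_AFTER_FWD)    # post_fwd_exec: forward done, kept for backward
t.trans_state(TensorState.COMPUTE)           # pre_bwd_exec: about to be used in backward
t.trans_state(TensorState.HOLD_AFTER_BWD)    # post_bwd_exec: backward done
t.trans_state(TensorState.HOLD)              # ShardedModelV2 resets to HOLD around forward()
```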


@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
class ShardedModelV2(nn.Module):
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
f.write(str(self._memstats_collector.non_model_data_cuda_GB))
f.write('\n')
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
def _pre_forward_operations(self):
if self._iter_cnter == 0 and self._memstats_collector:
# the opeartion will affect the flag in ZeroHook
# the operation will affect the memory tracer behavior in ZeroHook
self._memstats_collector.start_collection()
for p in self.module.parameters():
if hasattr(p, 'col_attr'):
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def _post_forward_operations(self):
for p in self.module.parameters():
if hasattr(p, 'col_attr'):
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_operations()
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
outputs = self.module(*args, **kwargs)
self._post_forward_operations()
return outputs
def backward(self, loss):
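The wrapped forward() above casts floating-point inputs to fp16 before calling the module. A rough sketch of what helpers like cast_float_arguments / cast_tensor_to_fp16 do (this is an assumption about their behavior, not the library's exact implementation):

```python
import torch
from typing import Any, Dict, Tuple

def cast_tensor_to_fp16(t: torch.Tensor) -> torch.Tensor:
    # Only floating-point tensors are cast; integer/bool tensors pass through.
    return t.half() if t.is_floating_point() else t

def cast_float_arguments(fn, *args: Any, **kwargs: Any) -> Tuple[Tuple, Dict]:
    # Apply `fn` to every tensor argument and leave everything else untouched.
    new_args = tuple(fn(a) if isinstance(a, torch.Tensor) else a for a in args)
    new_kwargs = {k: fn(v) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return new_args, new_kwargs

x = torch.randn(4, 4)                                   # fp32 activations
mask = torch.ones(4, dtype=torch.bool)                  # non-float input stays untouched
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, x, mask=mask)
assert args[0].dtype == torch.float16 and kwargs["mask"].dtype == torch.bool
```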
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
def _post_backward_operations(self) -> None:
"""
The method includes the operations required after the backward pass:
1. update memory tracer.
2. flush the gradients in buckets, reducing partial gradients across processes.
3. shard tensors not dealt with in the zero hook
4. move sharded param grad payload to param.grad
"""
# 1. update memory tracer.
self._update_memstats()
# 2. flush the gradients in buckets, reducing partial gradients across processes.
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self.comm_stream):
@@ -207,44 +226,50 @@ class ShardedModelV2(nn.Module):
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
self.reducer.free()
# In case some post bwd hook is not fired
# 3. shard tensors not dealt with in the zero hook
if self.shard_param:
tensor_list = []
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
tensor_list.append(p.col_attr.sharded_data_tensor)
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
self.shard_strategy.shard(tensor_list, self.process_group)
# 4. move sharded param grad payload to param.grad
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
continue
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# sync passes, if desired.
# Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
# NOTE: a no-sync pass skips the gradient all-reduce across the process group, while a sync pass performs it.
# When _require_backward_grad_sync is False,
# p.grad remains the unsharded gradient accumulated from prior no-sync passes.
# Interleaving no-sync passes with sync passes is also allowed, if desired.
if not self._require_backward_grad_sync:
continue
# Write grad back to p.grad and set p.col_attr.grad to None
# Write grad payload kept by sharded param back to p.grad,
# and set p.col_attr.grad to None
# As the sharded optimizer only updates a shard of the param,
# no matter whether we shard the param in the sharded model,
# we have to make sure the grad is a flat tensor shard.
# If world size == 1 and sharded param,
# If world size == 1 and param is sharded,
# the shape of `grad` is the same as the unsharded param,
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
if self.reuse_fp16_shard:
grad_payload = p.col_attr.sharded_data_tensor.payload
grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
else:
grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
assert isinstance(grad_payload, torch.Tensor)
grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
assert isinstance(grad_fp16_payload, torch.Tensor)
if p.col_attr.offload_grad:
colo_model_data_move_to_cpu(grad_payload)
colo_model_data_move_to_cpu(grad_fp16_payload)
if not p.col_attr.fp32_grad.is_null():
assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
grad_payload = p.col_attr.fp32_grad.payload
p.grad.data = grad_payload
p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
grad_fp16_payload = p.col_attr.fp32_grad.payload
p.col_attr.fp32_grad.set_null()
p.grad.data = grad_fp16_payload
p.col_attr.fp16_grad.set_null()
p.col_attr.fp32_grad.set_null()
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
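Step 4 of _post_backward_operations is the part this commit renames and tightens: the fp16 grad payload is cast to fp32, optionally accumulated into fp32_grad, written back to p.grad as a flat shard, and both stateful grad slots are nulled. A rough, self-contained sketch of that flow, with a stand-in container instead of the real col_attr:

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class GradSlots:
    """Stand-in for the fp16/fp32 grad slots kept on a sharded param."""
    fp16_grad: Optional[torch.Tensor] = None    # grad produced by this backward (fp16)
    fp32_grad: Optional[torch.Tensor] = None    # locally accumulated grad (fp32), if any

def write_back_grad(p: torch.nn.Parameter, slots: GradSlots) -> None:
    grad = slots.fp16_grad.float()              # cast to fp32 for the sharded optimizer
    if slots.fp32_grad is not None:
        # Gradient accumulation: fold the fresh grad into the stored fp32 grad.
        slots.fp32_grad.add_(grad.view_as(slots.fp32_grad))
        grad = slots.fp32_grad
    p.grad = grad.view(-1)                      # hand a flat tensor shard to autograd's slot
    slots.fp16_grad = None                      # mirrors fp16_grad.set_null() / fp32_grad.set_null()
    slots.fp32_grad = None

p = torch.nn.Parameter(torch.zeros(4))
slots = GradSlots(fp16_grad=torch.ones(4, dtype=torch.float16))
write_back_grad(p, slots)
assert p.grad.dtype == torch.float32 and slots.fp16_grad is None
```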


@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval: float = 1000,
hysteresis: float = 2,
max_scale: int = 2**32,
use_memory_tracer=False,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'


@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState
class ShardedParamV2(object):
def __init__(self,
param: torch.nn.Parameter,
process_group: Optional[dist.ProcessGroup] = None,
rm_torch_payload=False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel
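With process_group dropped from the constructor, building a sharded param is now a one-argument call. A usage sketch mirroring the updated test at the bottom of this diff (the import path is assumed and may differ between versions):

```python
import torch
from colossalai.zero.sharded_param import ShardedParamV2  # import path assumed

param = torch.nn.Parameter(torch.randn(2, 3))
sparam = ShardedParamV2(param=param)            # no process_group argument anymore

# The data tensor keeps the original payload; both grad slots start out FREE (null payload).
assert torch.equal(sparam.sharded_data_tensor.payload, param.data)
assert sparam.fp16_grad.is_null() and sparam.fp32_grad.is_null()
```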


@@ -1,22 +1,20 @@
import torch
import torch.distributed as dist
from typing import Optional
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from typing import Optional
class ShardedTensor(StatefulTensor):
def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
r"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
super().__init__(tensor)
self.trans_state(TensorState.HOLD)
super().__init__(tensor, state)
# Keep the shape, numel and dtype of the initial tensor.
self._origin_shape = tensor.shape
self._origin_numel = tensor.numel()
self._origin_dtype = tensor.dtype
self._is_sharded = False
@property


@@ -1,5 +1,5 @@
from enum import Enum
from logging import NullHandler
from typing import Optional
import torch
@@ -8,22 +8,22 @@ class TensorState(Enum):
HOLD = 1
HOLD_AFTER_FWD = 2
HOLD_AFTER_BWD = 3
COMPUTE = 4
class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states.
"""A Structure stores a Torch Tensor and labeled states.
Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
if state is not TensorState.FREE:
self._payload = tensor
else:
self._payload = None
self._payload = tensor
if self._state == TensorState.FREE:
assert self._payload is None, f"payload has to be None if {self._state}"
def data_ptr(self):
if self._payload is None:
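The constructor change above amounts to one contract: the payload is stored as passed in, and a FREE tensor must carry no payload. A short usage sketch of the API as it appears in this diff (module path taken from the imports earlier in the diff):

```python
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

grad_slot = StatefulTensor(None, TensorState.FREE)   # how fp16_grad / fp32_grad start out
data = StatefulTensor(torch.randn(2, 3))             # state defaults to TensorState.HOLD

assert grad_slot.is_null()                           # FREE implies no payload
data.trans_state(TensorState.COMPUTE)                # labeled busy around an op, as ZeroHook does
data.trans_state(TensorState.HOLD_AFTER_FWD)         # and parked again afterwards
```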


@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):
param = torch.nn.Parameter(torch.randn(2, 3))
param_ref = deepcopy(param)
sparam = ShardedParamV2(param=param, process_group=None)
sparam = ShardedParamV2(param=param)
allclose(sparam.sharded_data_tensor.payload, param_ref.data)