mirror of https://github.com/hpcaitech/ColossalAI

[zero] label state for param fp16 and grad (#551)
parent 92f4224867
commit f552b11294
@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_param.tensorful_state import TensorState

 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
             if param.col_attr.bwd_count == 0:
                 # We haven't stored local accumulated grad yet
                 assert param.col_attr.fp32_grad.is_null()
+
+                # Allocate grad fp32 memory space here
                 param.col_attr.fp32_grad.reset_payload(param.grad.data)
+                # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
                 param.grad = None
             else:
                 # We have stored local accumulated grad
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_bwd_exec(self, module: torch.nn.Module, input):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)

         for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
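For orientation, here is a minimal, self-contained sketch, not taken from this commit, of the order in which the hooks above appear to relabel a parameter's sharded data tensor over one training step. The enum values mirror the TensorState definition shown further below; the `lifecycle` list is purely illustrative.

from enum import Enum


class TensorState(Enum):
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


# Assumed labeling order for one parameter across a single training step:
lifecycle = [
    TensorState.HOLD,             # reset by the sharded model before/after forward
    TensorState.COMPUTE,          # ZeroHook.pre_fwd_exec
    TensorState.HOLD_AFTER_FWD,   # ZeroHook.post_fwd_exec
    TensorState.COMPUTE,          # ZeroHook.pre_bwd_exec
    TensorState.HOLD_AFTER_BWD,   # ZeroHook.post_bwd_exec
]
print(" -> ".join(s.name for s in lifecycle))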
@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState


 class ShardedModelV2(nn.Module):
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
                 f.write(str(self._memstats_collector.non_model_data_cuda_GB))
                 f.write('\n')

-    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def _pre_forward_operations(self):
         if self._iter_cnter == 0 and self._memstats_collector:
-            # the opeartion will affect the flag in ZeroHook
+            # the operation will affect the memory tracer behavior in ZeroHook
             self._memstats_collector.start_collection()
+
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def _post_forward_operations(self):
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._pre_forward_operations()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
+        self._post_forward_operations()
         return outputs

     def backward(self, loss):
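The refactor above splits the forward pass into explicit pre/post bookkeeping steps. A small stand-alone sketch of that wrapper pattern, with placeholder hook bodies rather than ColossalAI's real state handling, is:

import torch
import torch.nn as nn


class ForwardBookkeepingWrapper(nn.Module):
    """Illustrative only: bracket the wrapped module's forward with pre/post operations."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def _pre_forward_operations(self):
        # placeholder: e.g. start memory-stat collection, reset parameter states to HOLD
        pass

    def _post_forward_operations(self):
        # placeholder: e.g. reset parameter states to HOLD after the forward pass
        pass

    def forward(self, *args, **kwargs):
        self._pre_forward_operations()
        outputs = self.module(*args, **kwargs)
        self._post_forward_operations()
        return outputs


wrapped = ForwardBookkeepingWrapper(nn.Linear(8, 2))
out = wrapped(torch.randn(3, 8))
print(out.shape)  # torch.Size([3, 2])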
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
     def _post_backward_operations(self) -> None:
         """
         The method includes operations required to be processed after backward
+        1. update memory tracer.
+        2. flush the gradient in buckets. Reducing partial gradients in each process.
+        3. shard tensors not dealt with in the zero hook
+        4. move sharded param grad payload to param.grad
         """
+        # 1. update memory tracer.
         self._update_memstats()

+        # 2. flush the gradient in buckets. Reducing partial gradients in each process.
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -207,44 +226,50 @@ class ShardedModelV2(nn.Module):
                 # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                 torch.cuda.current_stream().synchronize()
             self.reducer.free()
         # In case some post bwd hook is not fired
+        # 3. shard tensors not dealt with in the zero hook
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
                     tensor_list.append(p.col_attr.sharded_data_tensor)
+                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
             self.shard_strategy.shard(tensor_list, self.process_group)

+        # 4. move sharded param grad payload to param.grad
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
-            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
-            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
-            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
-            # sync passes, if desired.
+            # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
+            # NOTE(): a (no-sync)/sync pass does not/does all-reduce gradients across the process group.
+            # If _require_backward_grad_sync is True,
+            # p.grad remains the accumulated unsharded gradient from prior no-sync passes.
+            # We also allow interleaving no-sync passes with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad back to p.grad and set p.col_attr.grad to None
+            # Write grad payload kept by sharded param back to p.grad,
+            # and set p.col_attr.grad to None
             # As sharded optimizer only update a shard of param,
             # no matter whether we shard param in sharded model
             # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and sharded param,
+            # If world size == 1 and param is sharded,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
-                assert isinstance(grad_payload, torch.Tensor)
+                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_payload)
+                colo_model_data_move_to_cpu(grad_fp16_payload)
             if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_payload = p.col_attr.fp32_grad.payload
-            p.grad.data = grad_payload
+                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_fp16_payload = p.col_attr.fp32_grad.payload
+                p.col_attr.fp32_grad.set_null()
+
+            p.grad.data = grad_fp16_payload
             p.col_attr.fp16_grad.set_null()
-            p.col_attr.fp32_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
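To make the write-back path above concrete, here is a hedged numeric sketch using plain tensors: the fp16 gradient shard is cast to fp32, accumulated into a persistent fp32 buffer when one exists, and the flat result is handed to the optimizer through `param.grad`. The buffer and shapes below are made up for illustration and are not ColossalAI objects.

import torch

param = torch.nn.Parameter(torch.zeros(4))
fp16_grad_shard = torch.full((4,), 0.5, dtype=torch.float16)   # stands in for fp16_grad.payload
fp32_accum = torch.zeros(4, dtype=torch.float32)               # stands in for fp32_grad.payload

# Cast the fp16 shard to fp32 (cf. cast_tensor_to_fp32 in the hunk above).
grad_fp32 = fp16_grad_shard.float()

# Accumulate into the persistent fp32 buffer, as done when a prior no-sync pass stored a gradient.
fp32_accum.add_(grad_fp32.view_as(fp32_accum))

# Hand the flat fp32 shard to the optimizer via param.grad.
param.grad = fp32_accum.view(-1)
print(param.grad)  # four entries of 0.5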
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
-                 use_memory_tracer=False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState

 class ShardedParamV2(object):

-    def __init__(self,
-                 param: torch.nn.Parameter,
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 rm_torch_payload=False) -> None:
-        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
@@ -1,22 +1,20 @@
 import torch
-import torch.distributed as dist
-from typing import Optional
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
+from typing import Optional


 class ShardedTensor(StatefulTensor):

-    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
-        super().__init__(tensor)
-        self.trans_state(TensorState.HOLD)
+        super().__init__(tensor, state)

         # kept the shape, numel and dtype of the init tensor.
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype

         self._is_sharded = False

     @property
@@ -1,5 +1,5 @@
 from enum import Enum
-from logging import NullHandler
+from typing import Optional
 import torch

@@ -8,22 +8,22 @@ class TensorState(Enum):
     HOLD = 1
     HOLD_AFTER_FWD = 2
     HOLD_AFTER_BWD = 3
     COMPUTE = 4


 class StatefulTensor(object):
     """A Structure stores a Torch Tensor and labeled states.
     Inspired by the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

     https://arxiv.org/abs/2108.05818
     """

-    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
-        if state is not TensorState.FREE:
-            self._payload = tensor
-        else:
-            self._payload = None
+        self._payload = tensor
+        if self._state == TensorState.FREE:
+            assert self._payload is None, f"payload has to be None if state is {self._state}"

     def data_ptr(self):
         if self._payload is None:
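Only fragments of StatefulTensor appear in the hunk above, so here is a hedged, toy re-implementation showing the new invariant (a FREE tensor carries no payload) together with the helpers referenced elsewhere in this diff (`is_null`, `set_null`, `reset_payload`, `trans_state`). The method bodies are guesses for illustration, not ColossalAI's code, and the numeric value of FREE is assumed.

from enum import Enum
from typing import Optional

import torch


class TensorState(Enum):
    FREE = 0          # value assumed; the hunk above only shows HOLD..COMPUTE
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class MiniStatefulTensor:
    """Toy stand-in for StatefulTensor: a payload plus a labeled state."""

    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
            assert self._payload is None, "payload has to be None when state is FREE"

    def is_null(self) -> bool:
        return self._state == TensorState.FREE

    def set_null(self) -> None:
        self._state = TensorState.FREE
        self._payload = None

    def reset_payload(self, tensor: torch.Tensor) -> None:
        self._payload = tensor
        self._state = TensorState.HOLD

    def trans_state(self, state: TensorState) -> None:
        self._state = state

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload


grad = MiniStatefulTensor(None, TensorState.FREE)    # like fp32_grad right after construction
assert grad.is_null()
grad.reset_payload(torch.zeros(2, 3))                # like the first backward pass storing a grad
assert not grad.is_null()
grad.set_null()                                      # like the cleanup after the grad write-back
assert grad.payload is None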
@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):

     param = torch.nn.Parameter(torch.randn(2, 3))
     param_ref = deepcopy(param)
-    sparam = ShardedParamV2(param=param, process_group=None)
+    sparam = ShardedParamV2(param=param)

     allclose(sparam.sharded_data_tensor.payload, param_ref.data)