[zero] label state for param fp16 and grad (#551)

Jiarui Fang 2022-03-30 15:57:46 +08:00 committed by GitHub
parent 92f4224867
commit f552b11294
7 changed files with 75 additions and 39 deletions

View File

@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
                 if param.col_attr.bwd_count == 0:
                     # We haven't stored local accumulated grad yet
                     assert param.col_attr.fp32_grad.is_null()
+                    # Allocate grad fp32 memory space here
                     param.col_attr.fp32_grad.reset_payload(param.grad.data)
+                    # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
                     param.grad = None
                 else:
                     # We have stored local accumulated grad
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_bwd_exec(self, module: torch.nn.Module, input):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
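For reference, the lifecycle these hook changes drive can be sketched in a few lines of standalone Python. The ToyStatefulParam class and run_layer helper below are illustrative assumptions rather than the ColossalAI API; only the TensorState names come from the diff above.

# Illustrative sketch only: ToyStatefulParam and run_layer are hypothetical helpers,
# not part of ColossalAI; they just mimic the trans_state calls made by ZeroHook.
from enum import Enum


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class ToyStatefulParam:
    """Toy stand-in for a param's sharded data tensor carrying a lifecycle label."""

    def __init__(self) -> None:
        self.state = TensorState.HOLD

    def trans_state(self, new_state: TensorState) -> None:
        self.state = new_state


def run_layer(param: ToyStatefulParam) -> None:
    # pre_fwd_exec: the shard is gathered and labeled COMPUTE
    param.trans_state(TensorState.COMPUTE)
    # ... the wrapped module's forward runs here ...
    # post_fwd_exec: the shard is kept but no longer being computed on
    param.trans_state(TensorState.HOLD_AFTER_FWD)
    # pre_bwd_exec / post_bwd_exec mirror the same pattern during backward
    param.trans_state(TensorState.COMPUTE)
    param.trans_state(TensorState.HOLD_AFTER_BWD)


if __name__ == '__main__':
    p = ToyStatefulParam()
    run_layer(p)
    print(p.state)  # TensorState.HOLD_AFTER_BWD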

View File

@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState


 class ShardedModelV2(nn.Module):
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
                 f.write(str(self._memstats_collector.non_model_data_cuda_GB))
                 f.write('\n')

-    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def _pre_forward_operations(self):
         if self._iter_cnter == 0 and self._memstats_collector:
-            # the opeartion will affect the flag in ZeroHook
+            # the operation will affect the memory tracer behavior in ZeroHook
             self._memstats_collector.start_collection()
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def _post_forward_operations(self):
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._pre_forward_operations()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
+        self._post_forward_operations()
         return outputs

     def backward(self, loss):
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
     def _post_backward_operations(self) -> None:
         """
         The method includes operations required to be processed after backward:
+        1. update the memory tracer.
+        2. flush gradients in the buckets, reducing partial gradients across processes.
+        3. shard tensors not dealt with in the zero hook.
+        4. move the sharded param grad payload to param.grad.
         """
+        # 1. update the memory tracer.
         self._update_memstats()
+        # 2. flush gradients in the buckets, reducing partial gradients across processes.
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -207,45 +226,51 @@ class ShardedModelV2(nn.Module):
                 # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                 torch.cuda.current_stream().synchronize()
             self.reducer.free()
-        # In case some post bwd hook is not fired
+
+        # 3. shard tensors not dealt with in the zero hook
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
                     tensor_list.append(p.col_attr.sharded_data_tensor)
+                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
             self.shard_strategy.shard(tensor_list, self.process_group)
+
+        # 4. move the sharded param grad payload to param.grad
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
-            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
-            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
-            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
-            # sync passes, if desired.
+            # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
+            # NOTE: a no-sync pass skips the gradient all-reduce across the process group; a sync pass performs it.
+            # On a no-sync pass (_require_backward_grad_sync is False),
+            # p.grad remains the unsharded gradient accumulated from prior no-sync passes.
+            # Interleaving no-sync passes with sync passes is also allowed, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad back to p.grad and set p.col_attr.grad to None
+            # Write the grad payload kept by the sharded param back to p.grad,
+            # and set p.col_attr.grad to None.
             # As the sharded optimizer only updates a shard of the param,
             # no matter whether we shard the param in the sharded model,
             # we have to make sure the grad is a flat tensor shard.
-            # If world size == 1 and sharded param,
+            # If world size == 1 and the param is sharded,
             # the shape of `grad` is the same as the unsharded param,
             # so we can just use `view(-1)` to ensure grad is a flat tensor shard.
             if self.reuse_fp16_shard:
-                grad_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
-            assert isinstance(grad_payload, torch.Tensor)
+                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+            assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_payload)
+                colo_model_data_move_to_cpu(grad_fp16_payload)
+
             if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_payload = p.col_attr.fp32_grad.payload
-            p.grad.data = grad_payload
-            p.col_attr.fp16_grad.set_null()
+                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_fp16_payload = p.col_attr.fp32_grad.payload
             p.col_attr.fp32_grad.set_null()
+
+            p.grad.data = grad_fp16_payload
+            p.col_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
         """

View File

@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
-                 use_memory_tracer=False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

View File

@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState

 class ShardedParamV2(object):

-    def __init__(self,
-                 param: torch.nn.Parameter,
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 rm_torch_payload=False) -> None:
-        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel

View File

@@ -1,22 +1,20 @@
 import torch
-import torch.distributed as dist
-from typing import Optional
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
+from typing import Optional


 class ShardedTensor(StatefulTensor):

-    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
-        super().__init__(tensor)
-        self.trans_state(TensorState.HOLD)
+        super().__init__(tensor, state)

+        # keep the shape, numel and dtype of the initial tensor
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype
         self._is_sharded = False

     @property

View File

@@ -1,5 +1,5 @@
 from enum import Enum
-from logging import NullHandler
+from typing import Optional

 import torch
@@ -8,22 +8,22 @@ class TensorState(Enum):
     HOLD = 1
     HOLD_AFTER_FWD = 2
     HOLD_AFTER_BWD = 3
+    COMPUTE = 4


 class StatefulTensor(object):
     """A structure that stores a Torch Tensor and its labeled state.
+    Inspired by the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
     https://arxiv.org/abs/2108.05818
     """

-    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
-        if state is not TensorState.FREE:
-            self._payload = tensor
-        else:
-            self._payload = None
+        self._payload = tensor
+        if self._state == TensorState.FREE:
+            assert self._payload is None, f"payload has to be None if state is {self._state}"

     def data_ptr(self):
         if self._payload is None:
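To make the new semantics concrete, the following is a condensed, standalone reimplementation of what the two hunks above establish: a payload paired with a lifecycle label, where FREE implies there is no payload. It is a rough sketch for illustration and omits the rest of the real StatefulTensor interface (payload management, data_ptr, and so on).

from enum import Enum
from typing import Optional

import torch


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class StatefulTensor:
    """Stores a torch.Tensor payload together with a lifecycle label (sketch)."""

    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
            # A FREE tensor must not keep a payload alive.
            assert self._payload is None, f"payload has to be None if state is {self._state}"

    def trans_state(self, state: TensorState) -> None:
        self._state = state

    @property
    def state(self) -> TensorState:
        return self._state


# Usage mirroring the hooks above: grads start FREE, data tensors default to HOLD
# and cycle through COMPUTE and HOLD_AFTER_FWD / HOLD_AFTER_BWD during a step.
fp16_grad = StatefulTensor(None, TensorState.FREE)
data = StatefulTensor(torch.randn(2, 3))
data.trans_state(TensorState.COMPUTE)
data.trans_state(TensorState.HOLD_AFTER_FWD)
print(fp16_grad.state, data.state)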

View File

@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):
     param = torch.nn.Parameter(torch.randn(2, 3))
     param_ref = deepcopy(param)

-    sparam = ShardedParamV2(param=param, process_group=None)
+    sparam = ShardedParamV2(param=param)

     allclose(sparam.sharded_data_tensor.payload, param_ref.data)