mirror of https://github.com/hpcaitech/ColossalAI

[zero] label state for param fp16 and grad (#551)
parent 92f4224867
commit f552b11294
@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_param.tensorful_state import TensorState

 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
                 if param.col_attr.bwd_count == 0:
                     # We haven't stored local accumulated grad yet
                     assert param.col_attr.fp32_grad.is_null()
+
+                    # Allocate grad fp32 memory space here
                     param.col_attr.fp32_grad.reset_payload(param.grad.data)
+                    # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
                     param.grad = None
                 else:
                     # We have stored local accumulated grad
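For context, the branch above is where the hook decides whether a freshly computed gradient becomes the stored fp32 accumulation buffer (first backward, bwd_count == 0) or is folded into an existing one. Below is a minimal sketch of that decision with plain torch tensors; GradSlot and store_or_accumulate are illustrative stand-ins, not ColossalAI APIs.

import torch


class GradSlot:
    """Illustrative stand-in for the fp32_grad StatefulTensor used above."""

    def __init__(self):
        self.payload = None

    def is_null(self):
        return self.payload is None

    def reset_payload(self, tensor):
        self.payload = tensor


def store_or_accumulate(grad: torch.Tensor, fp32_grad: GradSlot, bwd_count: int) -> int:
    if bwd_count == 0:
        # First backward for this param: take ownership of the grad buffer.
        assert fp32_grad.is_null()
        fp32_grad.reset_payload(grad.data)
    else:
        # Later backward passes (gradient accumulation): add into the stored buffer.
        fp32_grad.payload.add_(grad.data)
    return bwd_count + 1


slot, count = GradSlot(), 0
count = store_or_accumulate(torch.ones(4), slot, count)
count = store_or_accumulate(torch.ones(4), slot, count)
print(slot.payload)  # tensor([2., 2., 2., 2.])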
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_bwd_exec(self, module: torch.nn.Module, input):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)

         for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
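Putting the four hook callbacks together, each parameter's sharded data tensor cycles through the labels added in this commit: HOLD before forward, COMPUTE while the op runs, HOLD_AFTER_FWD between forward and backward, COMPUTE again during backward, and HOLD_AFTER_BWD afterwards. A self-contained toy sketch of that cycle follows; the classes mirror the naming but are not the ColossalAI implementation.

from enum import Enum


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class ToyStatefulTensor:
    def __init__(self, state: TensorState = TensorState.HOLD):
        self.state = state
        self.history = [state]

    def trans_state(self, state: TensorState):
        self.state = state
        self.history.append(state)


def run_one_step(t: ToyStatefulTensor):
    # pre_fwd_exec: the gathered param is used for computation
    t.trans_state(TensorState.COMPUTE)
    # post_fwd_exec: forward done, payload kept for the backward pass
    t.trans_state(TensorState.HOLD_AFTER_FWD)
    # pre_bwd_exec: gathered again for gradient computation
    t.trans_state(TensorState.COMPUTE)
    # post_bwd_exec: backward done, payload kept until it is re-sharded
    t.trans_state(TensorState.HOLD_AFTER_BWD)


t = ToyStatefulTensor()
run_one_step(t)
print([s.name for s in t.history])
# ['HOLD', 'COMPUTE', 'HOLD_AFTER_FWD', 'COMPUTE', 'HOLD_AFTER_BWD']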
@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState


 class ShardedModelV2(nn.Module):
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
                 f.write(str(self._memstats_collector.non_model_data_cuda_GB))
                 f.write('\n')

-    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def _pre_forward_operations(self):
         if self._iter_cnter == 0 and self._memstats_collector:
-            # the opeartion will affect the flag in ZeroHook
+            # the operation will affect the memory tracer behavior in ZeroHook
             self._memstats_collector.start_collection()

+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def _post_forward_operations(self):
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._pre_forward_operations()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
+        self._post_forward_operations()
         return outputs

     def backward(self, loss):
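The hunk above factors the bookkeeping around the wrapped module call into _pre_forward_operations and _post_forward_operations, both of which reset every labeled parameter to HOLD. Here is a hedged sketch of the same wrap pattern on a plain nn.Module; StateResettingWrapper and its label dict are illustrative, not part of ShardedModelV2.

import torch
import torch.nn as nn


class StateResettingWrapper(nn.Module):
    """Illustrative wrapper: reset per-parameter labels around forward, as in the refactor above."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # toy per-parameter label, standing in for col_attr.sharded_data_tensor's state
        self._labels = {name: 'HOLD' for name, _ in module.named_parameters()}

    def _pre_forward_operations(self):
        # label every param HOLD before forward; hooks would move them to COMPUTE during the op
        for name in self._labels:
            self._labels[name] = 'HOLD'

    def _post_forward_operations(self):
        # label every param HOLD again once forward is done
        for name in self._labels:
            self._labels[name] = 'HOLD'

    def forward(self, *args, **kwargs):
        self._pre_forward_operations()
        outputs = self.module(*args, **kwargs)
        self._post_forward_operations()
        return outputs


model = StateResettingWrapper(nn.Linear(4, 2))
out = model(torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])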
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
     def _post_backward_operations(self) -> None:
         """
         The method includes operations required to be processed after backward
+        1. update memory tracer.
+        2. flush the gradient in buckets. Reducing partial gradients in each process.
+        3. shard tensors not dealt with in the zero hook
+        4. move sharded param grad payload to param.grad
         """
+        # 1. update memory tracer.
         self._update_memstats()

+        # 2. flush the gradient in buckets. Reducing partial gradients in each process.
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -207,45 +226,51 @@ class ShardedModelV2(nn.Module):
                 # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                 torch.cuda.current_stream().synchronize()
             self.reducer.free()
-        # In case some post bwd hook is not fired
+        # 3. shard tensors not dealt with in the zero hook
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
                     tensor_list.append(p.col_attr.sharded_data_tensor)
+                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
             self.shard_strategy.shard(tensor_list, self.process_group)

+        # 4. move sharded param grad payload to param.grad
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
                 continue
-            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
-            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
-            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
-            # sync passes, if desired.
+            # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
+            # NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient allreducing between process group.
+            # If _require_backward_grad_sync is True,
+            # p.grad remains the accumulated unsharded gradient from prior no-sync passes.
+            # We also allow interleaving no-sync passes with sync passes, if desired.
            if not self._require_backward_grad_sync:
                 continue
-            # Write grad back to p.grad and set p.col_attr.grad to None
+            # Write the grad payload kept by the sharded param back to p.grad,
+            # and set p.col_attr.grad to None
             # As sharded optimizer only update a shard of param,
             # no matter whether we shard param in sharded model
             # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and sharded param,
+            # If world size == 1 and param is sharded,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
-            assert isinstance(grad_payload, torch.Tensor)
+                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+            assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_payload)
+                colo_model_data_move_to_cpu(grad_fp16_payload)
             if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_payload = p.col_attr.fp32_grad.payload
-            p.grad.data = grad_payload
-            p.col_attr.fp16_grad.set_null()
-            p.col_attr.fp32_grad.set_null()
+                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_fp16_payload = p.col_attr.fp32_grad.payload
+                p.col_attr.fp32_grad.set_null()
+
+            p.grad.data = grad_fp16_payload
+            p.col_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
         """
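Step 4 above casts the fp16 gradient shard to fp32, optionally folds it into a previously stored fp32 accumulation buffer, and installs the result as p.grad for the sharded optimizer. A self-contained sketch of that data flow with plain tensors follows; write_back_grad and stored_fp32_grad are illustrative names standing in for the col_attr fields.

from typing import Optional

import torch


def write_back_grad(p: torch.nn.Parameter,
                    fp16_grad_shard: torch.Tensor,
                    stored_fp32_grad: Optional[torch.Tensor] = None) -> None:
    """Install the (possibly accumulated) fp32 gradient shard as p.grad."""
    grad = fp16_grad_shard.float()          # like cast_tensor_to_fp32
    if stored_fp32_grad is not None:
        # fold the new shard into the accumulated fp32 buffer, then use the buffer
        stored_fp32_grad.add_(grad.view_as(stored_fp32_grad))
        grad = stored_fp32_grad
    # the sharded optimizer expects a flat shard; p.grad must already exist here
    p.grad.data = grad


p = torch.nn.Parameter(torch.zeros(4))
p.grad = torch.zeros(4)                      # placeholder grad, as after backward
write_back_grad(p, torch.ones(4, dtype=torch.float16), stored_fp32_grad=torch.full((4,), 2.0))
print(p.grad)  # tensor([3., 3., 3., 3.])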
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
-                 use_memory_tracer=False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState

 class ShardedParamV2(object):

-    def __init__(self,
-                 param: torch.nn.Parameter,
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 rm_torch_payload=False) -> None:
-        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
@@ -1,22 +1,20 @@
 import torch
-import torch.distributed as dist
-from typing import Optional
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
+from typing import Optional


 class ShardedTensor(StatefulTensor):

-    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
-        super().__init__(tensor)
-        self.trans_state(TensorState.HOLD)
+        super().__init__(tensor, state)

+        # kept the shape, numel and dtype of the init tensor.
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype

         self._is_sharded = False

     @property
@@ -1,5 +1,5 @@
 from enum import Enum
-from logging import NullHandler
+from typing import Optional
 import torch

@@ -8,22 +8,22 @@ class TensorState(Enum):
     HOLD = 1
     HOLD_AFTER_FWD = 2
     HOLD_AFTER_BWD = 3
+    COMPUTE = 4


 class StatefulTensor(object):
     """A Structure stores a Torch Tensor and labeled states.
+    Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

     https://arxiv.org/abs/2108.05818
     """

-    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
-        if state is not TensorState.FREE:
-            self._payload = tensor
-        else:
-            self._payload = None
+        self._payload = tensor
+        if self._state == TensorState.FREE:
+            assert self._payload is None, f"payload has to be None if state is {self._state}"

     def data_ptr(self):
         if self._payload is None:
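The reworked __init__ above keeps whatever payload it is given and only enforces the FREE invariant (a FREE tensor must carry no payload), which is what allows ShardedParamV2 to create its grad slots as StatefulTensor(None, TensorState.FREE). A minimal toy sketch of that invariant follows; ToyStatefulTensor mirrors the shape of the class but is not the ColossalAI code.

from enum import Enum
from typing import Optional

import torch


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class ToyStatefulTensor:
    def __init__(self, tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor
        if self._state == TensorState.FREE:
            # FREE labels "no memory held": a FREE tensor must not carry a payload
            assert self._payload is None, f"payload has to be None if state is {self._state}"

    def is_null(self) -> bool:
        return self._payload is None

    def set_null(self) -> None:
        self._state = TensorState.FREE
        self._payload = None


grad_slot = ToyStatefulTensor(None, TensorState.FREE)   # like fp16_grad / fp32_grad in ShardedParamV2
print(grad_slot.is_null())                               # True

data = ToyStatefulTensor(torch.randn(2, 3))              # defaults to HOLD with a payload
data.set_null()
print(data.is_null())                                    # True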
@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):

     param = torch.nn.Parameter(torch.randn(2, 3))
     param_ref = deepcopy(param)
-    sparam = ShardedParamV2(param=param, process_group=None)
+    sparam = ShardedParamV2(param=param)

     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
