# ColossalAI/colossalai/zero/sharded_model/sharded_model_v2.py
#
# 283 lines
# 14 KiB
# Python

import functools
from collections import OrderedDict
from typing import Any, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils.commons.memory import col_cuda_memory_capacity
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
class ShardedModelV2(nn.Module):
    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.

    Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
    compared to classic data parallelism while the computational granularity and communication efficiency
    are retained. Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.

    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
    :type module: nn.Module
    :param shard_strategy: A shard strategy to manage shard behavior.
    :type shard_strategy: BaseShardStrategy
    :param process_group: Data parallel process group, defaults to None
    :type process_group: Optional[ProcessGroup], optional
    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None.
        Generally, it should be `None`.
    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
    :type reduce_scatter_bucket_size_mb: int, optional
    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter,
        defaults to False
    :type fp32_reduce_scatter: bool, optional
    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable
        CPU offload, defaults to None
    :type offload_config: Optional[dict], optional
    :param gradient_predivide_factor: Gradient is divided by this value before reduce-scatter,
        defaults to 1.0. If `None`, the value is chosen by `get_gradient_predivide_factor(world_size)`.
    :type gradient_predivide_factor: Optional[float], optional
    :param use_memory_tracer: Whether to use memory tracer, defaults to False
    :type use_memory_tracer: bool, optional
    """

    def __init__(self,
                 module: nn.Module,
                 shard_strategy: BaseShardStrategy,
                 process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_bucket_size_mb: int = 25,
                 fp32_reduce_scatter: bool = False,
                 offload_config: Optional[dict] = None,
                 gradient_predivide_factor: Optional[float] = 1.0,
                 use_memory_tracer: bool = False):
        super().__init__()
        self.logger = get_dist_logger()

        # We force users to use ZeroInitContext: every parameter must carry `col_attr`,
        # and the parameters must be uniformly sharded or uniformly unsharded.
        shard_flags = []
        for param in module.parameters():
            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
            shard_flags.append(param.col_attr.param_is_sharded)
        assert all(shard_flags) or not any(shard_flags), \
            'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
        self.shard_param = all(shard_flags)

        self.module = module
        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
        self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
        self.world_size = dist.get_world_size(self.process_group)
        self.rank = dist.get_rank(self.process_group)
        self.shard_strategy = shard_strategy

        # Init Memory Statistics Collector
        self._use_memory_tracer = use_memory_tracer
        self._memstats_collector = MemStatsCollector() if self._use_memory_tracer else None
        self._iter_cnter = 0

        # Register hooks
        register_ophooks_recursively(self.module,
                                     [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

        self.fp32_reduce_scatter = fp32_reduce_scatter
        # CPU offload is enabled only by `offload_config == {"device": "cpu", ...}`.
        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
        for param in module.parameters():
            # Init `offload_fp32_grad`
            param.col_attr.offload_fp32_grad = self._cpu_offload

        # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
        # So we use 1.0 as the default gradient_predivide_factor
        # However, if you set gradient_predivide_factor to None, we will set
        # gradient_predivide_factor to a value >= 1.0 automatically
        self.gradient_predivide_factor: float = gradient_predivide_factor if \
            gradient_predivide_factor is not None else \
            get_gradient_predivide_factor(self.world_size)
        # Pre- and post-division together divide the summed gradient by world_size.
        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
        self._require_backward_grad_sync: bool = True

        self._cuda_margin_space = 0

    @property
    def cuda_margin_space(self):
        """CUDA memory capacity minus the peak sampled usage; 0 until it has been computed."""
        return self._cuda_margin_space

    @property
    def cpu_offload(self):
        """Whether CPU offload was requested through `offload_config`."""
        return self._cpu_offload

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the wrapped module on the inputs.

        The inputs are first passed through `cast_float_arguments` with `cast_tensor_to_fp16`.
        On the first iteration, memory statistics collection is started when the tracer is enabled.
        """
        if self._iter_cnter == 0 and self._memstats_collector:
            # the operation will affect the flag in ZeroHook
            self._memstats_collector.start_collection()

        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
        outputs = self.module(*args, **kwargs)
        return outputs

    def backward(self, loss):
        """Back-propagate `loss`, then run the post-backward operations."""
        loss.backward()
        self._post_backward_operations()

    def backward_by_grad(self, tensor, grad):
        """Back-propagate from `tensor` with the given output gradient, then run the post-backward operations."""
        torch.autograd.backward(tensors=tensor, grad_tensors=grad)
        self._post_backward_operations()

    @torch.no_grad()
    def _post_backward_operations(self) -> None:
        """
        The method includes operations required to be processed after backward:
        finishing memory tracing, flushing pending reduce-scatter buckets,
        re-sharding parameters and writing the reduced gradient shard to ``p.grad``.
        """
        if self._iter_cnter == 0 and self._memstats_collector:
            self._memstats_collector.finish_collection()
        if self._memstats_collector:
            self._memstats_collector.reset_sampling_cnter()
            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
            # the way to calculate margin space is based on the assumption that
            # model data is fixed in cuda during training.
            # cuda margin space can be used to store OS.
            overall_cuda = self._memstats_collector._overall_cuda
            # `max()` raises ValueError on an empty sequence; keep the previous
            # margin when no sample has been recorded.
            if len(overall_cuda) > 0:
                self._cuda_margin_space = col_cuda_memory_capacity() - max(overall_cuda)

        self._iter_cnter += 1
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self.comm_stream):
                self.reducer.flush()
            torch.cuda.current_stream().wait_stream(self.comm_stream)
            if self._cpu_offload:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()
        self.reducer.free()

        # In case some post bwd hook is not fired
        if self.shard_param:
            for p in self.module.parameters():
                if not p.col_attr.param_is_sharded:
                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)

        for p in self.module.parameters():
            p.col_attr.bwd_count = 0
            if not p.requires_grad:
                continue
            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
            # sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue
            # Write grad back to p.grad and reset p.col_attr.fp16_grad / fp32_grad to None
            # As sharded optimizer only update a shard of param,
            # no matter whether we shard param in sharded model
            # We have to make sure the grad is a flat tensor shard
            # If world size == 1 and sharded param,
            # the shape `grad` is the same as unsharded param
            # So we can just use `view(-1)` to ensure grad is a flat tensor shard
            # NOTE(review): assumes the backward hook fired for this parameter (so `fp16_grad`
            # is set and `p.grad` exists) - TODO confirm the behaviour for unused parameters.
            grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
            if p.col_attr.offload_fp32_grad:
                col_move_to_cpu(grad)
            if p.col_attr.fp32_grad is not None:
                # Accumulate into the existing fp32 gradient and expose that one.
                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
                grad = p.col_attr.fp32_grad
            p.grad.data = grad.view(-1)
            p.col_attr.fp16_grad = None
            p.col_attr.fp32_grad = None

    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """
        At the start of :func:`_grad_post_backward_hook`, ``grad`` contains the
        full gradient for the local batch. The reduce-scatter op will save
        a single shard of the summed gradient across all
        GPUs to param.col_attr.fp16_grad. This shard will align with the current GPU rank. For example::

            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]

            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8

        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by `param.col_attr.fp16_grad`, which ensures that
        the local optimizer only sees the relevant parameter shard.

        Returns ``None`` when ``grad`` is ``None`` or gradient synchronization is disabled;
        otherwise returns a tensor created with ``torch.empty_like(grad)`` and passed
        through ``free_storage``.
        """
        if grad is None:
            return
        assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
        if not self._require_backward_grad_sync:
            return

        # The communication stream must see the gradient produced on the current stream.
        self.comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.comm_stream):
            new_grad = grad.clone()
            if self.fp32_reduce_scatter:
                # NOTE(review): this casts to `param.dtype`; it is FP32 only if the parameter
                # itself is FP32 - confirm against the documented `fp32_reduce_scatter` contract.
                new_grad.data = new_grad.data.to(param.dtype)
            if self.gradient_predivide_factor > 1.0:
                # Average grad by world_size for consistency with PyTorch DDP.
                new_grad.data.div_(self.gradient_predivide_factor)
            orig_grad_data = new_grad.data
            if self.world_size > 1:
                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
                self.reducer.reduce_scatter_async(grad_chunks,
                                                  group=self.reduce_scatter_process_group,
                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
            else:
                # Single process: nothing to reduce, hand the gradient to the callback directly.
                self._reduce_scatter_callback(param, new_grad)
            # Tell the caching allocator that this tensor is in use on the communication stream.
            orig_grad_data.record_stream(self.comm_stream)
        torch.cuda.current_stream().wait_stream(self.comm_stream)
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        return empty_grad

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Post-divide the reduced gradient in place and store it in `param.col_attr.fp16_grad`."""
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        param.col_attr.fp16_grad = reduced_grad.data

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        """Return the wrapped module's state dict built from the gathered parameter payloads.

        The shard strategy gathers every parameter's `sharded_data_tensor`, each `p.data` is
        temporarily pointed at the gathered payload, and the wrapped module's `state_dict` is
        taken. Afterwards the tensors are handed back to the shard strategy and `p.data` is
        restored, even when building the state dict raises.
        """
        params = list(self.module.parameters())
        sharded_tensors = [p.col_attr.sharded_data_tensor for p in params]
        self.shard_strategy.gather(sharded_tensors, self.process_group)
        prev_params = {}
        try:
            for p in params:
                prev_params[p] = p.data
                p.data = p.col_attr.sharded_data_tensor.payload
            # Keyword arguments: positional use of `nn.Module.state_dict` is deprecated in recent PyTorch.
            gathered_state_dict = self.module.state_dict(destination=destination,
                                                         prefix=prefix,
                                                         keep_vars=keep_vars)
        finally:
            # Undo the gather and the data swap regardless of the outcome.
            self.shard_strategy.shard(sharded_tensors, self.process_group)
            for p, data in prev_params.items():
                p.data = data
        return gathered_state_dict

    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        """Not supported yet; always raises `NotImplementedError`."""
        raise NotImplementedError

    def __getitem__(self, idx: int):
        """Index into the wrapped module, which must be an `nn.ModuleList`."""
        assert isinstance(self.module, nn.ModuleList)
        return self.module[idx]

    def __len__(self):
        """Length of the wrapped module, which must be an `nn.ModuleList`."""
        assert isinstance(self.module, nn.ModuleList)
        return len(self.module)

    def __iter__(self):
        """Iterate over the wrapped module, which must be an `nn.ModuleList`."""
        assert isinstance(self.module, nn.ModuleList)
        return iter(self.module)