ColossalAI/colossalai/zero/sharded_model/sharded_model_v2.py

246 lines
12 KiB
Python

import functools
from collections import OrderedDict
from typing import Any, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils.commons.memory import col_cuda_memory_capacity
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
class ShardedModelV2(nn.Module):
def __init__(self,
module: nn.Module,
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
shard_param: bool = True,
use_memory_tracer: bool = False):
r"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
"""
super().__init__()
self.logger = get_dist_logger()
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group)
# Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
self.module = module.half().cuda()
self.shard_strategy = shard_strategy
self.shard_param = shard_param
# In case user didn't use ZeroInitContext
for param in self.module.parameters():
if not hasattr(param, 'col_attr'):
param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.data])
# Init Memory Statistics Collector
self._use_memory_tracer = use_memory_tracer
if self._use_memory_tracer:
self._memstats_collector = MemStatsCollector()
else:
self._memstats_collector = None
self._iter_cnter = 0
# Register hooks
register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
self.fp32_reduce_scatter = fp32_reduce_scatter
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set
# gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = gradient_predivide_factor if \
gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
self._require_backward_grad_sync: bool = True
self._cuda_margin_space = 0
@property
def cuda_margin_space(self):
return self._cuda_margin_space
@property
def cpu_offload(self):
return self._cpu_offload
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
if self._iter_cnter == 0 and self._memstats_collector:
# the opeartion will affect the flag in ZeroHook
self._memstats_collector.start_collection()
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
outputs = self.module(*args, **kwargs)
return outputs
def backward(self, loss):
loss.backward()
self._post_backward_operations()
def backward_by_grad(self, tensor, grad):
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
self._post_backward_operations()
@torch.no_grad()
def _post_backward_operations(self) -> None:
"""
The method includes operations required to be processed after backward
"""
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
if self._memstats_collector:
self._memstats_collector.reset_sampling_cnter()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = col_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
self._iter_cnter += 1
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self.comm_stream):
self.reducer.flush()
torch.cuda.current_stream().wait_stream(self.comm_stream)
if self._cpu_offload:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
self.reducer.free()
# In case some post bwd hook is not fired
if self.shard_param:
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.data])
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
continue
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# sync passes, if desired.
if not self._require_backward_grad_sync:
continue
# Write grad back to p.grad and set p.col_attr.grad to None
# As sharded optimizer only update a shard of param,
# no matter whether we shard param in sharded model
# We have to make sure the grad is a flat tensor shard
# If world size == 1 and sharded param,
# the shape `grad` is the same as unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
if self._cpu_offload:
col_move_to_cpu(grad)
if p.col_attr.fp32_grad is not None:
p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
grad = p.col_attr.fp32_grad
p.grad.data = grad.view(-1)
p.col_attr.fp16_grad = None
p.col_attr.fp32_grad = None
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
"""
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will save
a single shard of the summed gradient across all
GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
param.grad (GPU #1): [5, 6, 7, 8]
after reduce_scatter:
param.grad (GPU #0): [6, 8] # 1+5, 2+6
param.grad (GPU #1): [10, 12] # 3+7, 4+8
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by `param.col_attr.grad`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
if grad is None:
return
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
if not self._require_backward_grad_sync:
return
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
new_grad = grad.clone()
if self.fp32_reduce_scatter:
new_grad.data = new_grad.data.to(param.dtype)
if self.gradient_predivide_factor > 1.0:
# Average grad by world_size for consistency with PyTorch DDP.
new_grad.data.div_(self.gradient_predivide_factor)
orig_grad_data = new_grad.data
if self.world_size > 1:
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
self.reducer.reduce_scatter_async(grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param))
else:
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
torch.cuda.current_stream().wait_stream(self.comm_stream)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
param.col_attr.fp16_grad = reduced_grad.data
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
prev_params = {}
for p in self.module.parameters():
prev_params[p] = p.data
p.data = p.col_attr.data.payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()])
for p in self.module.parameters():
p.data = prev_params[p]
return gathered_state_dict
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
raise NotImplementedError