|
|
|
@ -1,4 +1,3 @@
|
|
|
|
|
|
|
|
|
|
import functools
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
@ -7,11 +6,10 @@ import torch.distributed as dist
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
|
|
|
|
|
register_ophooks_recursively)
|
|
|
|
|
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
|
|
|
|
|
from colossalai.engine.paramhooks import BaseParamHookMgr
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.zero.shard_param import ShardParam
|
|
|
|
|
from colossalai.zero.sharded_param import ShardedParam
|
|
|
|
|
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
|
|
|
|
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
|
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ShardedModelV2(nn.Module):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
module: nn.Module,
|
|
|
|
|
process_group: Optional[ProcessGroup] = None,
|
|
|
|
|
reduce_scatter_process_group: Optional[ProcessGroup] = None,
|
|
|
|
|
reduce_scatter_bucket_size_mb: int = 25,
|
|
|
|
|
reshard_after_forward: bool = True,
|
|
|
|
|
mixed_precision: bool = False,
|
|
|
|
|
fp32_reduce_scatter: bool = False,
|
|
|
|
|
offload_config: Optional[dict] = None,
|
|
|
|
|
gradient_predivide_factor: Optional[float] = 1.0,
|
|
|
|
|
):
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
module: nn.Module,
|
|
|
|
|
process_group: Optional[ProcessGroup] = None,
|
|
|
|
|
reduce_scatter_process_group: Optional[ProcessGroup] = None,
|
|
|
|
|
reduce_scatter_bucket_size_mb: int = 25,
|
|
|
|
|
reshard_after_forward: bool = True,
|
|
|
|
|
mixed_precision: bool = False,
|
|
|
|
|
fp32_reduce_scatter: bool = False,
|
|
|
|
|
offload_config: Optional[dict] = None,
|
|
|
|
|
gradient_predivide_factor: Optional[float] = 1.0,
|
|
|
|
|
):
|
|
|
|
|
r"""
|
|
|
|
|
A demo to reconfigure the zero1 sharded model.
|
|
|
|
|
Currently, the optimizer states are not considered.
|
|
|
|
@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
|
|
|
|
|
# Shard the parameters at first
|
|
|
|
|
for _, param in self.module.named_parameters():
|
|
|
|
|
param.ca_attr = ShardParam(param)
|
|
|
|
|
param.ca_attr = ShardedParam(param)
|
|
|
|
|
param.ca_attr.shard()
|
|
|
|
|
param._sharded_grad = ShardedGradient(param, self, offload_config)
|
|
|
|
|
|
|
|
|
@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
|
|
|
|
|
# We found that a gradient_predivide_factor other than 1.0 may cause precision problems
|
|
|
|
|
# So we use 1.0 as the default gradient_predivide_factor
|
|
|
|
|
# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
|
|
|
|
|
self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
|
|
|
|
|
# However, if you set gradient_predivide_factor to None,
|
|
|
|
|
# we will set gradient_predivide_factor to a value >= 1.0 automatically
|
|
|
|
|
self.gradient_predivide_factor: float = \
|
|
|
|
|
gradient_predivide_factor if gradient_predivide_factor is not None else \
|
|
|
|
|
get_gradient_predivide_factor(self.world_size)
|
|
|
|
|
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
|
|
|
|
|
|
|
|
|
@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
|
|
|
|
|
"""
|
|
|
|
|
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
|
|
|
|
|
full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
|
|
|
|
|
full gradient for the local batch. The reduce-scatter op will save
|
|
|
|
|
a single shard of the summed gradient across all
|
|
|
|
|
GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
|
|
|
|
|
|
|
|
|
|
before reduce_scatter:
|
|
|
|
@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
|
|
|
|
|
orig_grad_data = new_grad.data
|
|
|
|
|
if self.world_size > 1:
|
|
|
|
|
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
|
|
|
|
|
self.reducer.reduce_scatter_async(
|
|
|
|
|
grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
|
|
|
|
|
self.reducer.reduce_scatter_async(grad_chunks,
|
|
|
|
|
group=self.reduce_scatter_process_group,
|
|
|
|
|
callback_fn=functools.partial(self._reduce_scatter_callback, param))
|
|
|
|
|
else:
|
|
|
|
|
self._reduce_scatter_callback(param, new_grad)
|
|
|
|
|
orig_grad_data.record_stream(self.comm_stream)
|
|
|
|
|