[doc] Update docstring for ZeRO (#459)

* polish sharded model docstr

* polish sharded optim docstr

* polish zero docstr

* polish shard strategy docstr
pull/461/head
ver217 2022-03-18 16:48:20 +08:00 committed by GitHub
parent 84fd7c1d4d
commit fc8e6db005
4 changed files with 73 additions and 48 deletions


@@ -10,6 +10,11 @@ from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
which will fully utilize network bandwidth.
It is especially useful when sub-module contains bias,
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small).
"""
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
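
For intuition, here is a rough, framework-agnostic sketch of the bucketing idea (not the strategy's actual implementation; the helper name and structure are made up): flattening a sub-module's shards into one buffer lets a single all-gather move everything at once instead of issuing one small collective per tensor.

# Conceptual sketch only: shows why gathering a sub-module's shards as one
# flat buffer uses bandwidth better than per-tensor gathers. Not the real
# BucketTensorShardStrategy code; the helper below is hypothetical.
from typing import List, Optional

import torch
import torch.distributed as dist


def gather_bucketed(local_shards: List[torch.Tensor],
                    process_group: Optional[dist.ProcessGroup] = None) -> List[torch.Tensor]:
    world_size = dist.get_world_size(process_group)
    # One contiguous buffer holds this rank's shard of every tensor
    # (e.g. a linear layer's weight shard followed by its bias shard).
    flat = torch.cat([s.reshape(-1) for s in local_shards])
    gathered_buffers = [torch.empty_like(flat) for _ in range(world_size)]
    dist.all_gather(gathered_buffers, flat, group=process_group)
    # Cut each rank's buffer back into per-tensor pieces and stitch them together.
    full_tensors, offset = [], 0
    for shard in local_shards:
        n = shard.numel()
        full_tensors.append(torch.cat([buf[offset:offset + n] for buf in gathered_buffers]))
        offset += n
    return full_tensors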


@@ -9,6 +9,8 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
"""A naive implementation which shard each tensor evenly over all ranks
"""
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
for t in tensor_list:
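
A minimal pure-PyTorch sketch of the even-sharding scheme (conceptual only, not TensorShardStrategy's internals): pad the flattened tensor so it splits into equal chunks, keep this rank's chunk, and undo it later with an all-gather.

# Conceptual sketch of sharding one tensor evenly over all ranks; the helper
# names are hypothetical and do not mirror TensorShardStrategy internals.
from typing import Optional

import torch
import torch.distributed as dist


def shard_evenly(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    flat = tensor.reshape(-1)
    # Pad so the flattened tensor splits into world_size equal chunks.
    pad = (world_size - flat.numel() % world_size) % world_size
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    return flat.chunk(world_size)[rank].clone()


def gather_evenly(shard: torch.Tensor, orig_numel: int, orig_shape: torch.Size,
                  process_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    world_size = dist.get_world_size(process_group)
    pieces = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(pieces, shard, group=process_group)
    # Drop the padding and restore the original shape.
    return torch.cat(pieces)[:orig_numel].reshape(orig_shape)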


@@ -1,6 +1,6 @@
import functools
from collections import OrderedDict
from typing import Any, Optional, Type
from typing import Any, Optional
import torch
import torch.distributed as dist
@@ -16,7 +16,6 @@ from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -25,10 +24,34 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso
class ShardedModelV2(nn.Module):
"""A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
Parameters, gradients and optimizer states are sharded, so memory efficiency is boosted drastically
compared to classic data parallelism while the computational granularity and communication efficiency are retained.
Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
:param module: A sharded module, which must be initialized by `ZeroInitContext`.
:type module: nn.Module
:param shard_strategy: A shard strategy to manage shard behavior.
:type shard_strategy: BaseShardStrategy
:param process_group: Data parallel process group, defaults to None
:type process_group: Optional[ProcessGroup], optional
:param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
:type reduce_scatter_process_group: Optional[ProcessGroup], optional
:param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
:type reduce_scatter_bucket_size_mb: int, optional
:param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
:type fp32_reduce_scatter: bool, optional
:param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
:type offload_config: Optional[dict], optional
:param gradient_predivide_factor: Gradients are divided by this value before reduce-scatter, defaults to 1.0
:type gradient_predivide_factor: Optional[float], optional
:param use_memory_tracer: Whether to use the memory tracer, defaults to False
:type use_memory_tracer: bool, optional
"""
def __init__(self,
module: nn.Module,
shard_strategy: Type[BaseShardStrategy],
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
@@ -36,10 +59,6 @@ class ShardedModelV2(nn.Module):
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
use_memory_tracer: bool = False):
r"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
"""
super().__init__()
self.logger = get_dist_logger()
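
A hedged usage sketch built from the parameters documented above. The ShardedModelV2 keyword arguments follow the docstring; the import paths, the ZeroInitContext arguments and the no-argument TensorShardStrategy constructor are assumptions that may differ between versions.

# Usage sketch; lines marked "assumed" are not confirmed by this diff.
import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext            # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy     # assumed import path
from colossalai.zero.sharded_model import ShardedModelV2        # assumed import path

shard_strategy = TensorShardStrategy()                          # assumed no-arg constructor

# The module must be built inside ZeroInitContext so its parameters are
# already sharded when ShardedModelV2 wraps it (kwargs below are assumed).
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(1024, 1024)                               # stand-in for your model

sharded_model = ShardedModelV2(model,
                               shard_strategy,
                               offload_config={'device': 'cpu'},    # enable CPU offload
                               reduce_scatter_bucket_size_mb=25,
                               gradient_predivide_factor=1.0)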


@@ -25,6 +25,46 @@ class OptimState(Enum):
class ShardedOptimizerV2(ColossalaiOptimizer):
"""A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO) stage 3.
You must use `ShardedOptimizerV2` with `ShardedModelV2`.
:param sharded_model: A sharded model initialized by `ShardedModelV2`. The optimizer will use the
shard strategy provided by the sharded model to shard the fp32 parameter tensors.
:type sharded_model: ShardedModelV2
:param optimizer: An Optimizer instance.
:type optimizer: Optimizer
:param cpu_offload: Whether to offload optimizer states to CPU.
:type cpu_offload: bool
:param initial_scale: initial scale used by DynamicGradScaler
:type initial_scale: float
:param min_scale: min scale used by DynamicGradScaler
:type min_scale: float
:param growth_factor: growth_factor used by DynamicGradScaler
:type growth_factor: float
:param backoff_factor: backoff_factor used by DynamicGradScaler
:type backoff_factor: float
:param growth_interval: growth_interval used by DynamicGradScaler
:type growth_interval: float
:param hysteresis: hysteresis used by DynamicGradScaler
:type hysteresis: float
:param max_scale: max_scale used by DynamicGradScaler
:type max_scale: int
:param dp_process_group: data parallel process group
:type dp_process_group: Optional[ProcessGroup]
:param mp_process_group: model parallel process group
:type mp_process_group: Optional[ProcessGroup]
"""
def __init__(self,
sharded_model: ShardedModelV2,
@@ -39,47 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
max_scale: int = 2**32,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
:type sharded_model: sharded_model
:param optimizer_class: A class type of Optimizer
:type optimizer_class: Type[Optimizer]
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param initial_scale: initial scale used by DynamicGradScaler
:type initial_scale: float
:param min_scale: min scale used by DynamicGradScaler
:type min_scale: float
:param growth_factor: growth_factor used by DynamicGradScaler
:type growth_factor: float
:param backoff_factor: backoff_factor used by DynamicGradScaler
:type backoff_factor: float
:param growth_interval: growth_interval used by DynamicGradScaler
:type growth_interval: float
:param hysteresis: hysteresis used by DynamicGradScaler
:type hysteresis: float
:param max_scale: max_scale used by DynamicGradScaler
:type max_scale: float
:param dp_process_group: data paralle process group
:type dp_process_group: Optional[ProcessGroup]
:param mp_process_group: model paralle process group
:type mp_process_group: Optional[ProcessGroup]
:**defaults: any trailing arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
super().__init__(optimizer)
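
To close the loop, a hedged sketch of pairing ShardedOptimizerV2 with the ShardedModelV2 instance from the previous sketch. The constructor keywords follow the docstring; the import path, the data-loading helpers and the backward call in the loop are assumptions (some versions route backward through the model or an engine wrapper instead).

# Usage sketch continuing the ShardedModelV2 example; assumed details are marked.
import torch
from colossalai.zero.sharded_optim import ShardedOptimizerV2    # assumed import path

optimizer = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(sharded_model,
                                   optimizer,
                                   cpu_offload=False,           # keep fp32 states on GPU
                                   initial_scale=2**5,          # DynamicGradScaler settings
                                   growth_interval=1000,
                                   hysteresis=2)

for data, label in dataloader:                                  # hypothetical dataloader
    sharded_optim.zero_grad()
    loss = criterion(sharded_model(data), label)                # hypothetical criterion
    loss.backward()                                             # assumed; some versions require a model/engine backward helper
    sharded_optim.step()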