mirror of https://github.com/hpcaitech/ColossalAI
[doc] Update docstring for ZeRO (#459)
* polish sharded model docstr
* polish sharded optim docstr
* polish zero docstr
* polish shard strategy docstr

pull/461/head
parent 84fd7c1d4d
commit fc8e6db005
@@ -10,6 +10,11 @@ from .tensor_shard_strategy import TensorShardStrategy


class BucketTensorShardStrategy(TensorShardStrategy):
+    """Use the same shard scheme as `TensorShardStrategy`, but gather the tensors of a sub-module together,
+    which makes fuller use of network bandwidth.
+    It is especially useful when a sub-module contains a bias,
+    since we cannot utilize network bandwidth well if we only gather a bias tensor (a bias is usually small).
+    """

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
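To make the bandwidth argument concrete, here is a minimal, self-contained sketch of the bucketed-gather idea (illustrative only, not the library's internal code; `gather_bucket` is a hypothetical helper): flattening a sub-module's shards into one buffer replaces many tiny collectives, e.g. one for a small bias, with a single all-gather.

import torch
import torch.distributed as dist

def gather_bucket(local_shards, process_group=None):
    # Flatten this rank's shards (e.g. a weight shard plus a bias shard) into one buffer,
    # so a single all_gather is issued instead of one per (possibly tiny) tensor.
    flat = torch.cat([s.flatten() for s in local_shards])
    world_size = dist.get_world_size(process_group)
    gathered = [torch.empty_like(flat) for _ in range(world_size)]
    dist.all_gather(gathered, flat, group=process_group)  # one collective for the whole bucket
    return gathered  # per-rank flat buffers; a caller would re-split them per tensor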
@@ -9,6 +9,8 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


class TensorShardStrategy(BaseShardStrategy):
+    """A naive implementation which shards each tensor evenly over all ranks.
+    """

    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        for t in tensor_list:
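For illustration, a rough sketch of the "shard each tensor evenly" idea as a standalone function (not `TensorShardStrategy.shard` itself; names are hypothetical): each rank keeps one padded chunk of the flattened tensor.

import torch

def shard_evenly(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Flatten, pad so the length divides evenly across ranks, and keep only this rank's chunk.
    flat = tensor.flatten()
    chunk_size = (flat.numel() + world_size - 1) // world_size
    padded = torch.nn.functional.pad(flat, (0, chunk_size * world_size - flat.numel()))
    return padded[rank * chunk_size:(rank + 1) * chunk_size].clone()

# e.g. a 10-element tensor over 4 ranks -> each rank stores 3 elements (last chunk zero-padded).
chunks = [shard_evenly(torch.arange(10.), r, 4) for r in range(4)]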
@@ -1,6 +1,6 @@
import functools
from collections import OrderedDict
-from typing import Any, Optional, Type
+from typing import Any, Optional

import torch
import torch.distributed as dist
@@ -16,7 +16,6 @@ from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -25,10 +24,34 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso


class ShardedModelV2(nn.Module):
+    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
+    Parameters, gradients and optimizer states are sharded, so memory efficiency is boosted drastically
+    compared to classic data parallelism, while the computational granularity and communication efficiency are retained.
+    Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
+
+    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
+    :type module: nn.Module
+    :param shard_strategy: A shard strategy to manage shard behavior.
+    :type shard_strategy: BaseShardStrategy
+    :param process_group: Data parallel process group, defaults to None
+    :type process_group: Optional[ProcessGroup], optional
+    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
+    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
+    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
+    :type reduce_scatter_bucket_size_mb: int, optional
+    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
+    :type fp32_reduce_scatter: bool, optional
+    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
+    :type offload_config: Optional[dict], optional
+    :param gradient_predivide_factor: Gradient is divided by this value before reduce-scatter, defaults to 1.0
+    :type gradient_predivide_factor: Optional[float], optional
+    :param use_memory_tracer: Whether to use the memory tracer, defaults to False
+    :type use_memory_tracer: bool, optional
+    """

    def __init__(self,
                 module: nn.Module,
-                 shard_strategy: Type[BaseShardStrategy],
+                 shard_strategy: BaseShardStrategy,
                 process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
                 reduce_scatter_bucket_size_mb: int = 25,

@@ -36,10 +59,6 @@ class ShardedModelV2(nn.Module):
                 offload_config: Optional[dict] = None,
                 gradient_predivide_factor: Optional[float] = 1.0,
                 use_memory_tracer: bool = False):
-        r"""
-        A demo to reconfigure zero1 shared_model.
-        Currently do not consider the Optimizer States.
-        """
        super().__init__()
        self.logger = get_dist_logger()
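Based only on the parameters documented above, a hedged construction sketch (the import paths and the default-constructible `TensorShardStrategy()` are assumptions; the `ZeroInitContext` setup is elided because its arguments are not part of this diff):

import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# The wrapped module must have been created under ZeroInitContext (see its own docstring);
# a plain nn.Linear stands in here purely for illustration.
module: nn.Module = nn.Linear(1024, 1024)

zero_model = ShardedModelV2(
    module,
    shard_strategy=TensorShardStrategy(),  # assumed default-constructible
    offload_config={"device": "cpu"},      # optional: enable CPU offload
    gradient_predivide_factor=1.0,
)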
@@ -25,6 +25,46 @@ class OptimState(Enum):


class ShardedOptimizerV2(ColossalaiOptimizer):
+    """A wrapper for an optimizer. `ShardedOptimizerV2` and `ShardedModelV2` together implement Zero Redundancy Optimizer (ZeRO) stage 3.
+    You must use `ShardedOptimizerV2` with `ShardedModelV2`.
+
+    :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
+                          shard strategy provided by the sharded model to shard param fp32 tensors.
+    :type sharded_model: ShardedModelV2
+
+    :param optimizer: An Optimizer instance.
+    :type optimizer: Optimizer
+
+    :param cpu_offload: Whether to offload the optimizer states to CPU.
+    :type cpu_offload: bool
+
+    :param initial_scale: initial scale used by DynamicGradScaler
+    :type initial_scale: float
+
+    :param min_scale: min scale used by DynamicGradScaler
+    :type min_scale: float
+
+    :param growth_factor: growth factor used by DynamicGradScaler
+    :type growth_factor: float
+
+    :param backoff_factor: backoff factor used by DynamicGradScaler
+    :type backoff_factor: float
+
+    :param growth_interval: growth interval used by DynamicGradScaler
+    :type growth_interval: float
+
+    :param hysteresis: hysteresis used by DynamicGradScaler
+    :type hysteresis: float
+
+    :param max_scale: max scale used by DynamicGradScaler
+    :type max_scale: float
+
+    :param dp_process_group: data parallel process group
+    :type dp_process_group: Optional[ProcessGroup]
+
+    :param mp_process_group: model parallel process group
+    :type mp_process_group: Optional[ProcessGroup]
+    """

    def __init__(self,
                 sharded_model: ShardedModelV2,
@@ -39,47 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 max_scale: int = 2**32,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None) -> None:
-        """
-        :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
-                              shard strategy provided by sharded model to shard param fp32 tensors.
-        :type sharded_model: sharded_model
-
-        :param optimizer_class: A class type of Optimizer
-        :type optimizer_class: Type[Optimizer]
-
-        :param cpu_offload: is offloading the optimizer states to CPU.
-        :type cpu_offload: bool
-
-        :param initial_scale: initial scale used by DynamicGradScaler
-        :type initial_scale: float
-
-        :param min_scale: min scale used by DynamicGradScaler
-        :type min_scale: float
-
-        :param growth_factor: growth_factor used by DynamicGradScaler
-        :type growth_factor: float
-
-        :param backoff_factor: backoff_factor used by DynamicGradScaler
-        :type backoff_factor: float
-
-        :param growth_interval: growth_interval used by DynamicGradScaler
-        :type growth_interval: float
-
-        :param hysteresis: hysteresis used by DynamicGradScaler
-        :type hysteresis: float
-
-        :param max_scale: max_scale used by DynamicGradScaler
-        :type max_scale: float
-
-        :param dp_process_group: data paralle process group
-        :type dp_process_group: Optional[ProcessGroup]
-
-        :param mp_process_group: model paralle process group
-        :type mp_process_group: Optional[ProcessGroup]
-
-        :**defaults: any trailing arguments, which are forwarded to the local optimizer.
-        :type defaults: dict()
-        """
        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

        super().__init__(optimizer)
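Pairing the two wrappers, again as a hedged sketch based only on the docstrings above (the import path and the positional order of `sharded_model` and `optimizer` are assumptions; `zero_model` is reused from the earlier ShardedModelV2 sketch):

import torch
from colossalai.zero.sharded_optim import ShardedOptimizerV2

optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)  # any torch Optimizer instance
zero_optim = ShardedOptimizerV2(
    zero_model,         # must be a ShardedModelV2, as enforced by the assert above
    optimizer,
    cpu_offload=False,  # do not offload optimizer states to CPU
)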