mirror of https://github.com/hpcaitech/ColossalAI
[doc] Update docstring for ZeRO (#459)
* polish sharded model docstr
* polish sharded optim docstr
* polish zero docstr
* polish shard strategy docstr
parent 84fd7c1d4d
commit fc8e6db005
@@ -10,6 +10,11 @@ from .tensor_shard_strategy import TensorShardStrategy


 class BucketTensorShardStrategy(TensorShardStrategy):
+    """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers the tensors of a sub-module together,
+    which fully utilizes network bandwidth.
+    It is especially useful when a sub-module contains a bias,
+    since we cannot utilize network bandwidth well if we only gather a bias tensor (a bias is usually small).
+    """

     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
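The added docstring explains that `BucketTensorShardStrategy` gathers all shards of a sub-module in one pass so that small tensors such as biases do not waste bandwidth. Below is a minimal conceptual sketch of that bucketing idea; it is not the library implementation, the function name is invented, and it assumes an initialized `torch.distributed` process group with equally sized shards on every rank.

# Conceptual sketch of bucketed gathering, not BucketTensorShardStrategy itself.
from typing import List

import torch
import torch.distributed as dist


def bucketed_gather(shards: List[torch.Tensor]) -> List[torch.Tensor]:
    """Gather many small per-rank shards with a single all_gather collective."""
    world_size = dist.get_world_size()
    # Flatten every local shard into one contiguous bucket: one message per rank
    # instead of one tiny collective per (possibly very small) tensor.
    bucket = torch.cat([s.reshape(-1) for s in shards])
    buffers = [torch.empty_like(bucket) for _ in range(world_size)]
    dist.all_gather(buffers, bucket)
    # Split each rank's bucket back into per-tensor pieces and stitch the full tensors.
    sizes = [s.numel() for s in shards]
    pieces_per_rank = [buf.split(sizes) for buf in buffers]
    return [torch.cat([pieces[i] for pieces in pieces_per_rank]) for i in range(len(shards))]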
@@ -9,6 +9,8 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


 class TensorShardStrategy(BaseShardStrategy):
+    """A naive implementation which shards each tensor evenly over all ranks.
+    """

     def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         for t in tensor_list:
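For the naive strategy above, the sketch below illustrates what sharding a tensor evenly over all ranks means in practice. It is an illustration under invented names, not `TensorShardStrategy.shard` itself: the tensor is flattened, padded so every rank gets an equal chunk, and only the local chunk is kept.

import torch
import torch.nn.functional as F


def shard_tensor_evenly(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return the 1/world_size slice of `tensor` that this rank would keep."""
    flat = tensor.detach().reshape(-1)
    chunk = (flat.numel() + world_size - 1) // world_size       # ceil division
    flat = F.pad(flat, (0, chunk * world_size - flat.numel()))  # pad so all chunks are equal
    return flat[rank * chunk:(rank + 1) * chunk].clone()        # keep only the local shard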
@@ -1,6 +1,6 @@
 import functools
 from collections import OrderedDict
-from typing import Any, Optional, Type
+from typing import Any, Optional

 import torch
 import torch.distributed as dist
@@ -16,7 +16,6 @@ from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

@@ -25,10 +24,34 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso


 class ShardedModelV2(nn.Module):
+    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
+    Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
+    compared to classic data parallelism, while the computational granularity and communication efficiency are retained.
+    Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
+
+    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
+    :type module: nn.Module
+    :param shard_strategy: A shard strategy to manage shard behavior.
+    :type shard_strategy: BaseShardStrategy
+    :param process_group: Data parallel process group, defaults to None
+    :type process_group: Optional[ProcessGroup], optional
+    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
+    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
+    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
+    :type reduce_scatter_bucket_size_mb: int, optional
+    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
+    :type fp32_reduce_scatter: bool, optional
+    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
+    :type offload_config: Optional[dict], optional
+    :param gradient_predivide_factor: Gradient is divided by this value before reduce-scatter, defaults to 1.0
+    :type gradient_predivide_factor: Optional[float], optional
+    :param use_memory_tracer: Whether to use the memory tracer, defaults to False
+    :type use_memory_tracer: bool, optional
+    """

     def __init__(self,
                  module: nn.Module,
-                 shard_strategy: Type[BaseShardStrategy],
+                 shard_strategy: BaseShardStrategy,
                  process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_bucket_size_mb: int = 25,
@@ -36,10 +59,6 @@ class ShardedModelV2(nn.Module):
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
                  use_memory_tracer: bool = False):
-        r"""
-        A demo to reconfigure zero1 shared_model.
-        Currently do not consider the Optimizer States.
-        """
         super().__init__()
         self.logger = get_dist_logger()

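The new class docstring documents how `ShardedModelV2` is constructed, and the signature change shows that `shard_strategy` is now a strategy instance rather than a class. A hedged usage sketch follows; the import paths, the no-argument strategy constructor, and the toy `nn.Sequential` are assumptions for illustration (in real use the module must be created inside `ZeroInitContext`, as the docstring requires).

# Usage sketch only; details beyond the documented parameters are assumptions.
import torch.nn as nn

from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# In real use this module must be built inside ZeroInitContext (see :param module: above);
# a bare nn.Sequential keeps the sketch short.
model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

sharded_model = ShardedModelV2(model,
                               TensorShardStrategy(),             # strategy instance, per the new annotation
                               offload_config={"device": "cpu"},  # optional CPU offload
                               reduce_scatter_bucket_size_mb=25)  # documented default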
@@ -25,6 +25,46 @@ class OptimState(Enum):


 class ShardedOptimizerV2(ColossalaiOptimizer):
+    """A wrapper for an optimizer. `ShardedOptimizerV2` and `ShardedModelV2` together implement Zero Redundancy Optimizer (ZeRO) stage 3.
+    You must use `ShardedOptimizerV2` with `ShardedModelV2`.
+
+    :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
+    shard strategy provided by the sharded model to shard param fp32 tensors.
+    :type sharded_model: ShardedModelV2
+
+    :param optimizer: An Optimizer instance.
+    :type optimizer: Optimizer
+
+    :param cpu_offload: Whether to offload the optimizer states to CPU.
+    :type cpu_offload: bool
+
+    :param initial_scale: initial scale used by DynamicGradScaler
+    :type initial_scale: float
+
+    :param min_scale: min scale used by DynamicGradScaler
+    :type min_scale: float
+
+    :param growth_factor: growth_factor used by DynamicGradScaler
+    :type growth_factor: float
+
+    :param backoff_factor: backoff_factor used by DynamicGradScaler
+    :type backoff_factor: float
+
+    :param growth_interval: growth_interval used by DynamicGradScaler
+    :type growth_interval: float
+
+    :param hysteresis: hysteresis used by DynamicGradScaler
+    :type hysteresis: float
+
+    :param max_scale: max_scale used by DynamicGradScaler
+    :type max_scale: float
+
+    :param dp_process_group: data parallel process group
+    :type dp_process_group: Optional[ProcessGroup]
+
+    :param mp_process_group: model parallel process group
+    :type mp_process_group: Optional[ProcessGroup]
+    """

     def __init__(self,
                  sharded_model: ShardedModelV2,
@@ -39,47 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
-        """
-        :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
-        shard strategy provided by sharded model to shard param fp32 tensors.
-        :type sharded_model: sharded_model
-
-        :param optimizer_class: A class type of Optimizer
-        :type optimizer_class: Type[Optimizer]
-
-        :param cpu_offload: is offloading the optimizer states to CPU.
-        :type cpu_offload: bool
-
-        :param initial_scale: initial scale used by DynamicGradScaler
-        :type initial_scale: float
-
-        :param min_scale: min scale used by DynamicGradScaler
-        :type min_scale: float
-
-        :param growth_factor: growth_factor used by DynamicGradScaler
-        :type growth_factor: float
-
-        :param backoff_factor: backoff_factor used by DynamicGradScaler
-        :type backoff_factor: float
-
-        :param growth_interval: growth_interval used by DynamicGradScaler
-        :type growth_interval: float
-
-        :param hysteresis: hysteresis used by DynamicGradScaler
-        :type hysteresis: float
-
-        :param max_scale: max_scale used by DynamicGradScaler
-        :type max_scale: float
-
-        :param dp_process_group: data paralle process group
-        :type dp_process_group: Optional[ProcessGroup]
-
-        :param mp_process_group: model paralle process group
-        :type mp_process_group: Optional[ProcessGroup]
-
-        :**defaults: any trailing arguments, which are forwarded to the local optimizer.
-        :type defaults: dict()
-        """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

         super().__init__(optimizer)
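Putting the two wrappers together, here is a hedged end-to-end sketch continuing the `ShardedModelV2` example above. The import path, the choice of `torch.optim.Adam`, the loss-scaling arguments, and the `sharded_model.backward(loss)` call are assumptions drawn from the docstrings rather than verbatim API guarantees.

# End-to-end sketch assembled from the docstrings above; details are assumptions.
import torch

from colossalai.zero.sharded_optim import ShardedOptimizerV2

optimizer = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(sharded_model,       # must wrap a ShardedModelV2
                                   optimizer,
                                   cpu_offload=False,   # keep fp32 optimizer states on GPU
                                   initial_scale=2**5)  # DynamicGradScaler starting scale

data = torch.randn(4, 1024).cuda().half()  # dummy fp16 batch; dtype/device are assumptions
loss = sharded_model(data).float().sum()
sharded_model.backward(loss)               # assumed wrapper call: scales loss, reduce-scatters grads
sharded_optim.step()                       # update fp32 master params kept by the sharded optimizer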