mirror of https://github.com/hpcaitech/ColossalAI
fix zero3 fp16 and add zero3 model context (#62)
parent
9a0466534c
commit
7d3711058f
|
@ -220,7 +220,9 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
|||
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
sync_model_param_in_dp(model)
|
||||
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
|
||||
if not use_zero3:
|
||||
sync_model_param_in_dp(model)
|
||||
|
||||
# check amp and zero
|
||||
fp16_cfg = gpc.config.get('fp16', None)
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
|
||||
from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
|
||||
|
@ -12,11 +15,11 @@ def convert_to_zero(model: nn.Module,
|
|||
level: int,
|
||||
zero_config):
|
||||
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
|
||||
|
||||
if is_no_pp_or_last_stage():
|
||||
model = NaiveAMPModel(model, output_to_fp32=True)
|
||||
else:
|
||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
||||
if level == 2:
|
||||
if is_no_pp_or_last_stage():
|
||||
model = NaiveAMPModel(model, output_to_fp32=True)
|
||||
else:
|
||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
||||
|
||||
if level == 2:
|
||||
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
|
||||
|
@ -25,4 +28,71 @@ def convert_to_zero(model: nn.Module,
|
|||
return model, optimizer
|
||||
|
||||
|
||||
__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']
|
||||
def zero3_model_context(dtype=torch.half):
|
||||
"""A context to enable massive model construction for training with
|
||||
ZeRO-3. Models are automatically partitioned (or, sharded) across the
|
||||
system and converted to half precision. Note that the config of ZeRO-3 will be loaded automatically from `gpc.config`.
|
||||
|
||||
Args:
|
||||
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
|
||||
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``torch.half``
|
||||
|
||||
This context accelerates model initialization and enables models that
|
||||
are too large to allocate in their entirety in CPU memory. It has the
|
||||
following effects:
|
||||
|
||||
#. allocates tensors to either GPU or CPU memory or NVMe
|
||||
#. converts floating point tensors to half precision
|
||||
#. immediately partitions tensors among the group of data-parallel devices
|
||||
#. (*optional*) replaces ``torch.nn.functional.linear`` with a more
|
||||
memory-efficient implementation
|
||||
|
||||
These modifications allow for models that exceed the size of local CPU/GPU
|
||||
memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
|
||||
or GPU memory or NVMe) across all nodes. Consider initializing a model with one
|
||||
trillion parameters, whose weights occupy two terabytes (TB) in half
|
||||
precision. The initial CPU allocation in full precision requires 4TB of
|
||||
memory *per process*, and so a system with 8 GPUs per node would need 32TB of
|
||||
CPU memory due to data-parallel redundancies. Instead, by immediately
|
||||
partitioning tensors we remove the redundancies. The result is that
|
||||
regardless of the number of GPUs, we still only require the original 4TB. This
|
||||
allows for a linear increase in model size with the aggregate system memory.
|
||||
For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
|
||||
parameter model with 4 nodes and 32 GPUs.
|
||||
|
||||
Important: If the fp16 weights of the model can't fit onto a single GPU memory
|
||||
this feature must be used.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
#. Allocate a model and partition it among all processes:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with zero3_model_context():
|
||||
model = MyLargeModel()
|
||||
|
||||
"""
|
||||
assert dtype == torch.half or dtype == torch.float, f'Invalid dtype, except torch.half or torch.float, got {dtype}'
|
||||
import deepspeed
|
||||
ds_config = {
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"zero_optimization": {
|
||||
"offload_param": getattr(gpc.config.zero, 'offload_param_config', None),
|
||||
"offload_optimizer": getattr(gpc.config.zero, 'offload_optimizer_config'),
|
||||
},
|
||||
"aio": getattr(gpc.config.zero, 'aio_config', None)
|
||||
}
|
||||
remote_device = getattr(ds_config['zero_optimization']['offload_param'], 'device', None)
|
||||
pin_memory = getattr(ds_config['zero_optimization']['offload_param'], 'pin_memory', False)
|
||||
return deepspeed.zero.Init(data_parallel_group=gpc.get_group(ParallelMode.DATA),
|
||||
remote_device=remote_device,
|
||||
config_dict_or_path=ds_config,
|
||||
pin_memory=pin_memory,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2',
|
||||
'ZeroRedundancyOptimizer_Level_3', 'zero3_model_context']
|
||||
|
|
|
@ -637,7 +637,8 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
|
|||
postscale_gradients=True,
|
||||
gradient_predivide_factor=1.0,
|
||||
gradient_accumulation_steps=1,
|
||||
aio_config=None):
|
||||
aio_config=None,
|
||||
dtype=torch.half):
|
||||
# mpu = None
|
||||
# mpu is removed from the parameter list
|
||||
# tensor parallel will be automatically detected later
|
||||
|
@ -682,13 +683,25 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
|
|||
util_ops = UtilsBuilder().load()
|
||||
self.flatten = util_ops.flatten
|
||||
self.unflatten = util_ops.unflatten
|
||||
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
|
||||
self.dtype = dtype
|
||||
|
||||
if not all(is_zero_param(p) for p in module.parameters()):
|
||||
ds_config = {
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"zero_optimization": {
|
||||
"offload_param": offload_param_config,
|
||||
"offload_optimizer": offload_optimizer_config,
|
||||
},
|
||||
"aio": aio_config
|
||||
}
|
||||
remote_device = offload_param_config['device']
|
||||
group = None
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
group = gpc.get_group(ParallelMode.DATA)
|
||||
Init(module=module, data_parallel_group=group, dtype=self.dtype)
|
||||
Init(module=module, data_parallel_group=group, dtype=self.dtype,
|
||||
remote_device=remote_device, config_dict_or_path=ds_config,
|
||||
pin_memory=offload_optimizer_config[OFFLOAD_OPTIMIZER_PIN_MEMORY])
|
||||
|
||||
for m in module.modules():
|
||||
_init_external_params(m)
|
||||
|
|
|
@ -83,4 +83,13 @@ Note that `fp16` is automatically enabled when using ZeRO. This relies on `AMP_T
|
|||
|
||||
### Training
|
||||
|
||||
Note that if your model is too large to fit within the memory when using ZeRO-3, you should use `colossalai.zero.zero3_model_context` to construct your model:
|
||||
|
||||
```python
|
||||
from colossalai.zero import zero3_model_context
|
||||
|
||||
with zero3_model_context():
|
||||
model = Model()
|
||||
```
|
||||
|
||||
Once you have completed your configuration, just use `colossalai.initialize()` to initialize your training.
|
||||
|
|
|
@ -23,7 +23,7 @@ ZeRO优化器可以切分三种模型状态(优化器状态、梯度、参数
|
|||
)
|
||||
|
||||
zero = dict(
|
||||
type='ZeroRedundancyOptimizer_Level_3',
|
||||
level=3,
|
||||
dynamic_loss_scale=True,
|
||||
clip_grad=1.0
|
||||
)
|
||||
|
@ -78,4 +78,13 @@ ZeRO优化器可以切分三种模型状态(优化器状态、梯度、参数
|
|||
|
||||
### 使用ZeRO优化器进行训练
|
||||
|
||||
注意,当使用ZeRO-3时,如果您的模型过大以至于无法放入内存, 您应该使用`colossalai.zero.zero3_model_context`来构建您的模型:
|
||||
|
||||
```python
|
||||
from colossalai.zero import zero3_model_context
|
||||
|
||||
with zero3_model_context():
|
||||
model = Model()
|
||||
```
|
||||
|
||||
如果您完成了上述配置,可以运行`colossalai.initialize()`来开始您的训练。
|
||||
|
|
Loading…
Reference in New Issue