[Gemini] add GeminiAdamOptimizer (#1960)

pull/1962/head
Jiarui Fang 2022-11-16 14:44:28 +08:00 committed by GitHub
parent 7066dfbf82
commit f7e276fa71
12 changed files with 66 additions and 44 deletions

View File

@@ -0,0 +1,15 @@
from typing import Any
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
__all__ = ['GeminiAdamOptimizer']
class GeminiAdamOptimizer(ZeroOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)
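For context (not part of the diff): the new class folds the previous two-step setup, where a HybridAdam instance had to be wrapped in a ZeroOptimizer by hand, into one constructor call. A minimal usage sketch, assuming `model` has already been wrapped by ZeroDDP/GeminiDDP as ZeroOptimizer requires:

# Illustrative sketch only -- `model` is assumed to be a ZeroDDP/GeminiDDP-wrapped module.
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer

optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)

# Roughly equivalent to the earlier two-step construction:
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
# optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)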

View File

@@ -1,8 +1,10 @@
from typing import Any, Optional
import torch
from colossalai.utils import multi_tensor_applier
from colossalai.registry import OPTIMIZERS
from typing import Optional
from colossalai.utils import multi_tensor_applier
from .nvme_optimizer import NVMeOptimizer
@@ -11,7 +13,7 @@ class HybridAdam(NVMeOptimizer):
"""Implements Adam algorithm.
Supports parameter updates on both GPU and CPU, depending on the device of the parameters.
But the parameters and gradients should be on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
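A minimal sketch (not part of this commit) of the placement rule above: move the module so that the parameters and the gradients produced by backward() live on the same device before building the optimizer. The toy model and sizes are illustrative:

import torch
from colossalai.nn.optimizer import HybridAdam

# Parameters sit on the GPU, so backward() creates the gradients on the GPU too.
model = torch.nn.Linear(128, 128).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 128, device='cuda')).sum()
loss.backward()
optimizer.step()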
@@ -43,7 +45,7 @@ class HybridAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay (also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
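A hedged construction sketch for the NVMe offload options documented above; the offload fraction and the directory path are illustrative values, not taken from the commit:

from colossalai.nn.optimizer import HybridAdam

# Illustrative: offload half of the optimizer states to NVMe-backed storage.
optimizer = HybridAdam(model.parameters(),          # `model` as in the earlier sketch
                       lr=1e-3,
                       adamw_mode=True,             # decoupled weight decay (AdamW)
                       nvme_offload_fraction=0.5,   # 50% of optimizer states offloaded
                       nvme_offload_dir='/mnt/nvme/offload')  # example path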
@@ -68,14 +70,15 @@ class HybridAdam(NVMeOptimizer):
weight_decay=0,
adamw_mode=True,
nvme_offload_fraction: float = 0.0,
nvme_offload_dir: Optional[str] = None):
nvme_offload_dir: Optional[str] = None,
**defaults: Any):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import cpu_adam
import colossal_C
import cpu_adam
except ImportError:
raise ImportError('Please install colossalai from source code to use HybridAdam')

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, Set, Tuple
from typing import Any, Dict, Set, Tuple
import torch
import torch.distributed as dist
@@ -55,7 +55,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32):
max_scale: float = 2**32,
**defaults: Any):
super().__init__(optim)
assert isinstance(module, ZeroDDP)
self.module = module
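As the assertion shows, ZeroOptimizer only accepts a module that has already been wrapped by ZeroDDP; the new `**defaults` parameter lets wrappers such as GeminiAdamOptimizer forward extra optimizer kwargs through this signature. A hedged sketch, with the ZeroDDP wrapping step assumed to have happened already:

# Illustrative sketch -- `zero_model` is assumed to be a ZeroDDP instance.
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

optim = HybridAdam(zero_model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optim, zero_model,
                           initial_scale=2**5,   # dynamic loss scaling knobs from the signature above
                           growth_interval=1000,
                           hysteresis=2)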

View File

@@ -16,8 +16,9 @@ class GeminiDDP(ZeroDDP):
force_outputs_fp32: bool = False,
search_range_mb: int = 32) -> None:
"""
A torch.nn.Module wrapper using ZeRO-DP and Gemini.
ZeRO is for parallelism. Gemini is for memory management.
WARNING: The class will modify the module in place!
Example:
model is initialized under the context of ColoInitContext
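The docstring's example is cut off in this hunk; below is a hedged reconstruction combining the GeminiDDP arguments visible above with the ColoInitContext pattern used later in this commit. The `gpt2_medium` helper, the placement policy value, and any keyword names not shown in the hunk are assumptions:

from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Build the model under ColoInitContext so its parameters become ColoParameters.
with ColoInitContext(device='cpu'):
    model = gpt2_medium(checkpoint=True)

# Wrap it with GeminiDDP, which sets up ZeRO-DP plus Gemini memory management internally.
model = GeminiDDP(model,
                  device=get_current_device(),
                  placement_policy='cpu',       # illustrative placement policy
                  pin_memory=True,
                  force_outputs_fp32=False,
                  search_range_mb=32)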

View File

@@ -7,7 +7,7 @@ from colossalai.logging import get_dist_logger
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
from .zero_optimizer import ZeroOptimizer
from ..nn.optimizer.zero_optimizer import ZeroOptimizer
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,

View File

@@ -3,10 +3,8 @@ This example shows how to use Colossal-AI to run huggingface GPT training in dis
## GPT
We use the GPT2 model from huggingface transformers. The input data is randomly generated.
## Our Modifications
The `train_gpt_demo.py` provides three distributed plans, i.e. Colossal-AI, PyTorch DDP and ZeRO.
The Colossal-AI leverages Tensor Parallel and Gemini.
The `train_gpt_demo.py` provides three distributed plans, i.e. ColossalAI, PyTorch DDP and ZeRO.
The ColossalAI plan leverages Tensor Parallelism and Gemini.
## Quick Start
You can launch training by using the following bash script.

View File

@@ -10,11 +10,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
@@ -222,7 +223,7 @@ def main():
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
# build GPT model
with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_medium(checkpoint=True)
pg = default_pg
@@ -232,8 +233,9 @@
model = gemini_zero_dpp(model, pg, args.placement)
# build optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
# optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
elif args.distplan == "ddp":

View File

@@ -43,11 +43,11 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,

View File

@@ -30,13 +30,24 @@ from itertools import chain
import datasets
import torch
import torch.distributed as dist
import transformers
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@@ -50,17 +61,6 @@ from transformers import (
)
from transformers.utils.versions import require_version
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

View File

@@ -12,12 +12,12 @@ from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal

View File

@@ -9,12 +9,12 @@ import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed

View File

@@ -7,16 +7,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
@@ -96,19 +94,23 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
# The same as the following 3 lines
# chunk_manager = ChunkManager(config_dict, init_device=init_device)
# gemini_manager = GeminiManager(placement_policy, chunk_manager)
# model = ZeroDDP(model, gemini_manager, pin_memory=True)
zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
# The same as the following 2 lines
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
# zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
print(model.chunk_manager)  # the chunk manager is now created inside the GeminiDDP wrapper
check_param(model, torch_model, pg)
model.eval()