[Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Branch: pull/5931/head
Author: Edenzzzz, committed by GitHub (4 months ago)
Commit: 8cc8f645cd (parent: e86127925a)
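Each script now defers model construction behind a lazy-init context when the chosen plugin manages parameter placement itself: under `GeminiPlugin` (and `HybridParallelPlugin` in the OPT finetune demo) the model is built inside `LazyInitContext`, so parameters are created lazily rather than fully allocated on every rank, while all other plugins fall back to a `nullcontext` that leaves construction unchanged. A minimal sketch of the shared pattern, assuming `colossalai` is installed; `lazy_init_ctx` is a hypothetical helper name, not part of this commit:

```python
from contextlib import nullcontext

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext


def lazy_init_ctx(plugin):
    """Hypothetical helper mirroring this commit's pattern.

    Under Gemini, parameters are created lazily and materialized later by
    the plugin; under other plugins, model construction runs unchanged.
    """
    if isinstance(plugin, GeminiPlugin):
        return LazyInitContext(default_device=get_accelerator().get_current_device())
    return nullcontext()
```

Usage matches the diffs below, e.g. `with lazy_init_ctx(plugin): model = OPTForCausalLM(config=config)`.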

@@ -1,4 +1,5 @@
 import argparse
+from contextlib import nullcontext
 from typing import Callable, List, Union
 
 import evaluate
@@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 
 # ==============================
@@ -186,7 +188,6 @@ def main():
         help="only gpt2 now",
     )
     parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
-    parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
     args = parser.parse_args()
 
     if args.model_type == "gpt2":
@@ -250,6 +251,12 @@
         pad_token_id=data_builder.tokenizer.pad_token_id,
     )
 
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin))
+        else nullcontext()
+    )
+    with init_ctx:
         if model_name == "gpt2":
             model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
         else:
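The GPT example also drops the `--use_lazy_init` flag: lazy init is now inferred from the plugin type instead of being toggled by the user. The lazily built model is materialized when it is handed to the booster; a sketch of that continuation, not part of this diff, assuming the script already defines `optimizer`, `criterion`, and `lr_scheduler`:

```python
# When the model was built under LazyInitContext, boosting is the point at
# which the plugin materializes and places its parameters.
model, optimizer, criterion, _, lr_scheduler = booster.boost(
    model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler
)
```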

@@ -1,4 +1,5 @@
 import time
+from contextlib import nullcontext
 
 import torch
 import tqdm
@@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM
 from transformers.utils.versions import require_version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -62,14 +65,6 @@ def main():
     if args.mem_cap > 0:
         colo_memory_cap(args.mem_cap)
 
-    # Build OPT model
-    config = AutoConfig.from_pretrained(args.model_name_or_path)
-    model = OPTForCausalLM(config=config)
-    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
+    # Build OPT model
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin))
+        else nullcontext()
+    )
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    with init_ctx:
+        model = OPTForCausalLM(config=config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)
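A note on the benchmark above: the model is built from a config alone (`OPTForCausalLM(config=config)`), so lazy init only defers allocation of randomly initialized weights and no checkpoint is read; the "Finish loading model" log therefore no longer implies the weights are resident in memory. The finetune demo below uses `from_pretrained` instead, where the lazy context is presumably also meant to avoid materializing a full eager copy of the pretrained weights on every rank before the plugin shards them.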

@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import datasets
 import torch
 import transformers
@@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s
 from transformers.utils.versions import require_version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -78,14 +82,6 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()
 
-    # Build OPT model
-    config = AutoConfig.from_pretrained(args.model_name_or_path)
-    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
-    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
+    # Build OPT model
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        else nullcontext()
+    )
+    with init_ctx:
+        model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     dataset = NetflixDataset(tokenizer)
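Unlike the two scripts above, the finetune demo enables lazy init for `HybridParallelPlugin` as well (`isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))`), since that plugin likewise shards parameters at boost time and benefits from not allocating the full model on each rank first.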
