[Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Edenzzzz authored 4 months ago, committed by GitHub
parent e86127925a
commit 8cc8f645cd

@@ -1,4 +1,5 @@
import argparse
from contextlib import nullcontext
from typing import Callable, List, Union
import evaluate
@@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
# ==============================
@@ -186,7 +188,6 @@ def main():
help="only gpt2 now",
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()
if args.model_type == "gpt2":
@@ -250,10 +251,16 @@ def main():
pad_token_id=data_builder.tokenizer.pad_token_id,
)
if model_name == "gpt2":
model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
else:
raise RuntimeError
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin))
else nullcontext()
)
with init_ctx:
if model_name == "gpt2":
model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
else:
raise RuntimeError
# optimizer
no_decay = ["bias", "LayerNorm.weight"]

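For orientation, the hunks above reduce to the pattern sketched below. This is a minimal sketch and not part of the diff: it swaps GPT-2 for a toy nn.Sequential, assumes a torchrun launch with a recent ColossalAI where colossalai.launch_from_torch() takes no config dict, and uses only the APIs the example already imports (LazyInitContext, GeminiPlugin, Booster, HybridAdam).

from contextlib import nullcontext

import torch.nn as nn

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # assumes launch via torchrun; older releases needed a config dict
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# Lazy init only pays off for plugins that can materialize the model themselves.
init_ctx = (
    LazyInitContext(default_device=get_accelerator().get_current_device())
    if isinstance(plugin, GeminiPlugin)
    else nullcontext()
)

with init_ctx:
    # Toy stand-in for GPT2ForSequenceClassification: parameters created here
    # stay lazy; Gemini materializes and shards them inside booster.boost().
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

The real example follows the same flow; only the model construction inside the with-block and the optimizer's parameter groups differ.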
@@ -1,4 +1,5 @@
import time
from contextlib import nullcontext
import torch
import tqdm
@@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -62,14 +65,6 @@ def main():
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM(config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@ def main():
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Build OPT model
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin))
else nullcontext()
)
config = AutoConfig.from_pretrained(args.model_name_or_path)
with init_ctx:
model = OPTForCausalLM(config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

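As a sanity check on what the with-block buys here, the following standalone sketch (not from the diff; the opt-125m checkpoint name is purely illustrative) builds OPT from a config inside LazyInitContext and prints the CUDA memory footprint, which should stay near zero until a plugin such as Gemini materializes the parameters during booster.boost().

import torch
from transformers import AutoConfig, OPTForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.lazy import LazyInitContext

config = AutoConfig.from_pretrained("facebook/opt-125m")  # illustrative checkpoint name

with LazyInitContext(default_device=get_accelerator().get_current_device()):
    # Weights are recorded lazily instead of being allocated up front; in the
    # example above they are materialized later inside booster.boost().
    model = OPTForCausalLM(config=config)

print(f"CUDA memory after lazy build: {torch.cuda.memory_allocated() / 1e6:.1f} MB")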
@@ -1,3 +1,5 @@
from contextlib import nullcontext
import datasets
import torch
import transformers
@@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s
from transformers.utils.versions import require_version
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
@@ -78,14 +82,6 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@ def main():
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
else nullcontext()
)
with init_ctx:
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer)

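Since the same plugin check now appears in all three scripts, only with HybridParallelPlugin added to the tuple here, it could be factored into a small helper. The sketch below is a suggestion only; lazy_init_ctx_for is a hypothetical name that these examples do not define.

from contextlib import nullcontext

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin
from colossalai.lazy import LazyInitContext


def lazy_init_ctx_for(plugin):
    """Return a LazyInitContext for plugins that can materialize lazy modules,
    and a no-op context (eager init) for everything else."""
    if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)):
        return LazyInitContext(default_device=get_accelerator().get_current_device())
    return nullcontext()

Each script would then wrap model construction in "with lazy_init_ctx_for(plugin):", keeping the eager path for TorchDDPPlugin and LowLevelZeroPlugin unchanged.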