mirror of https://github.com/hpcaitech/ColossalAI
[Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
branch: pull/5931/head
parent: e86127925a
commit: 8cc8f645cd
GPT sequence-classification example:

@@ -1,4 +1,5 @@
 import argparse
+from contextlib import nullcontext
 from typing import Callable, List, Union
 
 import evaluate
@@ -17,6 +18,7 @@ from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 
 # ==============================
@@ -186,7 +188,6 @@ def main():
         help="only gpt2 now",
     )
     parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
-    parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
     args = parser.parse_args()
 
     if args.model_type == "gpt2":
@@ -250,6 +251,12 @@ def main():
         pad_token_id=data_builder.tokenizer.pad_token_id,
     )
 
-    if model_name == "gpt2":
-        model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
-    else:
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin))
+        else nullcontext()
+    )
+    with init_ctx:
+        if model_name == "gpt2":
+            model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+        else:
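The pattern is the same in all three files touched by this commit: build the model inside a LazyInitContext when the chosen plugin can materialize lazy parameters (Gemini here), otherwise inside a no-op nullcontext. A minimal self-contained sketch of the idiom, outside the diff; it assumes a working ColossalAI install, and the plugin choice and GPT-2 config are illustrative stand-ins rather than lines from the scripts:

# Sketch of the lazy-init idiom from this commit (not part of the diff).
from contextlib import nullcontext

from transformers import GPT2Config, GPT2ForSequenceClassification

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext

plugin = GeminiPlugin()  # in the real scripts this comes from --plugin

# Lazy init only when the plugin knows how to materialize lazy parameters;
# every other plugin gets a no-op context and eager initialization.
init_ctx = (
    LazyInitContext(default_device=get_accelerator().get_current_device())
    if isinstance(plugin, GeminiPlugin)
    else nullcontext()
)
with init_ctx:
    # Parameters created here are lazy placeholders, so a large model does
    # not have to fit in device memory all at once during construction.
    model = GPT2ForSequenceClassification(GPT2Config())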
OPT benchmark example:

@@ -1,4 +1,5 @@
 import time
+from contextlib import nullcontext
 
 import torch
 import tqdm
@@ -8,9 +9,11 @@ from transformers import AutoConfig, OPTForCausalLM
 from transformers.utils.versions import require_version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 
@@ -62,14 +65,6 @@ def main():
     if args.mem_cap > 0:
         colo_memory_cap(args.mem_cap)
 
-    # Build OPT model
-    config = AutoConfig.from_pretrained(args.model_name_or_path)
-    model = OPTForCausalLM(config=config)
-    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "torch_ddp_fp16":
@@ -82,6 +77,19 @@ def main():
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
+    # Build OPT model
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin))
+        else nullcontext()
+    )
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    with init_ctx:
+        model = OPTForCausalLM(config=config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
     # Set optimizer
     optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)
 
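Note why the model construction moved below the plugin setup: init_ctx is chosen based on which plugin was selected, so the model can only be built once the plugin exists. The lazy parameters then become real tensors when the model is handed to the booster. A hedged sketch of that hand-off, assuming the standard Booster.boost flow these examples use elsewhere; the nn.Linear stand-in, learning rate, and initial_scale are placeholder values, not lines from the commit:

# Sketch: how a lazily-initialized model is typically materialized.
from contextlib import nullcontext

import torch.nn as nn

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # expects torchrun-provided env vars

plugin = GeminiPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin)

with LazyInitContext(default_device=get_accelerator().get_current_device()):
    model = nn.Linear(4096, 4096)  # stand-in for OPTForCausalLM(config)

# Building the optimizer from lazy parameters is fine, as in the diff above.
optimizer = HybridAdam(model.parameters(), lr=1e-4)

# boost() wraps model and optimizer for the plugin; with Gemini it also
# materializes the lazy parameters and shards them across ranks.
model, optimizer, *_ = booster.boost(model, optimizer)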
OPT training example:

@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import datasets
 import torch
 import transformers
@@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s
 from transformers.utils.versions import require_version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 
@@ -78,14 +82,6 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()
 
-    # Build OPT model
-    config = AutoConfig.from_pretrained(args.model_name_or_path)
-    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
-    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@ def main():
 
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
+    # Build OPT model
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    # Build OPT model
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        else nullcontext()
+    )
+    with init_ctx:
+        model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     dataset = NetflixDataset(tokenizer)
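In the training demo the lazy path also covers HybridParallelPlugin, and from_pretrained runs inside the context, the intent being that full-size parameter tensors need not be allocated eagerly before the plugin shards them. When no booster is involved, ColossalAI also documents a manual path; a small sketch, assuming LazyInitContext.materialize behaves as described in the lazy-init docs (this helper does not appear in the commit itself):

# Sketch of manual materialization, an assumption based on ColossalAI's
# lazy-init documentation rather than on this commit.
import torch.nn as nn

from colossalai.lazy import LazyInitContext

with LazyInitContext():
    # Shapes and dtypes are tracked here, but no real storage is
    # allocated for the parameters yet.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# Allocate real tensors and run the deferred initializations.
LazyInitContext.materialize(model)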