@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import datasets
 import torch
 import transformers
@@ -8,9 +10,11 @@ from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_s
 from transformers.utils.versions import require_version
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
 
@@ -78,14 +82,6 @@ def main():
     datasets.utils.logging.set_verbosity_error()
     transformers.utils.logging.set_verbosity_error()
 
-    # Build OPT model
-    config = AutoConfig.from_pretrained(args.model_name_or_path)
-    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
-    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
-
-    # Enable gradient checkpointing
-    model.gradient_checkpointing_enable()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "torch_ddp_fp16":
@@ -110,6 +106,21 @@ def main():
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
+    # Build OPT model
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    # Initialize lazily for plugins that shard parameters at load time
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        else nullcontext()
+    )
+    with init_ctx:
+        model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
+    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     dataset = NetflixDataset(tokenizer)
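For reference, the initialization pattern this patch adopts, pulled out into a standalone form. This is a minimal sketch, not the full training script: it assumes a ColossalAI build that ships colossalai.lazy and colossalai.accelerator (matching the imports added above), and the checkpoint name and learning rate are illustrative.

from contextlib import nullcontext

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from transformers import OPTForCausalLM

# Older ColossalAI builds require launch_from_torch(config={}); newer ones
# take no config argument and read rank/world size from the torchrun env.
colossalai.launch_from_torch()

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# Under LazyInitContext, parameters are recorded as shape/dtype metadata
# instead of being allocated, so the full fp32 weights never need to fit
# on a single device before the plugin shards them.
init_ctx = LazyInitContext(default_device=get_accelerator().get_current_device())
with init_ctx:
    model = OPTForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative checkpoint

optimizer = HybridAdam(model.parameters(), lr=2e-5)  # illustrative lr
# boost() materializes the lazy parameters and shards them per the plugin.
model, optimizer, *_ = booster.boost(model, optimizer)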