From 35e22be2f6cf5dc5c85eebdb4848465d68722585 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Fri, 6 Jan 2023 10:08:41 +0800 Subject: [PATCH] [example] simplify opt example (#2344) --- examples/language/gpt/train_gpt_demo.py | 2 +- examples/language/opt/README.md | 21 +- examples/language/opt/benchmark.sh | 2 +- examples/language/opt/colossalai_zero.py | 6 - examples/language/opt/context.py | 32 -- examples/language/opt/requirements.txt | 6 - examples/language/opt/run_clm.py | 596 ---------------------- examples/language/opt/run_clm.sh | 22 - examples/language/opt/run_gemini.sh | 20 + examples/language/opt/train_gemini_opt.py | 211 ++++++++ 10 files changed, 234 insertions(+), 684 deletions(-) delete mode 100644 examples/language/opt/colossalai_zero.py delete mode 100644 examples/language/opt/context.py delete mode 100644 examples/language/opt/requirements.txt delete mode 100755 examples/language/opt/run_clm.py delete mode 100644 examples/language/opt/run_clm.sh create mode 100644 examples/language/opt/run_gemini.sh create mode 100755 examples/language/opt/train_gemini_opt.py diff --git a/examples/language/gpt/train_gpt_demo.py b/examples/language/gpt/train_gpt_demo.py index b18ff5111..ce71c6dde 100644 --- a/examples/language/gpt/train_gpt_demo.py +++ b/examples/language/gpt/train_gpt_demo.py @@ -5,7 +5,6 @@ from time import time import psutil import torch import torch.nn as nn -from model_zoo import model_builder from packaging import version from torch.nn.parallel import DistributedDataParallel as DDP from utils import get_data, get_tflops @@ -16,6 +15,7 @@ from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext +from model_zoo import model_builder CAI_VERSION = colossalai.__version__ diff --git a/examples/language/opt/README.md 
b/examples/language/opt/README.md index 75573b709..c2fd25457 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -29,24 +29,5 @@ We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. You can launch training by using the following bash script ```bash -bash ./run_clm.sh +bash ./run_gemini.sh ``` - -- batch-size-per-gpu: number of samples fed to each GPU, default is 16 -- mem-cap: limit memory usage within a value in GB, default is 0 (no limit) -- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request -the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). -- gpu-num: the number of GPUs to use, default is 1. - -## Remarkable Performance -On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed. -Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale. - -

- -

- -Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI! - -More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d), -and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon. diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh index f02f7629a..0d04b5e9b 100644 --- a/examples/language/opt/benchmark.sh +++ b/examples/language/opt/benchmark.sh @@ -14,7 +14,7 @@ do pkill -9 torchrun pkill -9 python -bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM +env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh done done done diff --git a/examples/language/opt/colossalai_zero.py b/examples/language/opt/colossalai_zero.py deleted file mode 100644 index 833745f3e..000000000 --- a/examples/language/opt/colossalai_zero.py +++ /dev/null @@ -1,6 +0,0 @@ -from colossalai.zero.shard_utils import TensorShardStrategy - -zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), - tensor_placement_policy="auto", - reuse_fp16_shard=True), - optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) diff --git a/examples/language/opt/context.py b/examples/language/opt/context.py deleted file mode 100644 index 95f0abf1d..000000000 --- a/examples/language/opt/context.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch.distributed as dist - -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc - - -class barrier_context(): - """ - This context manager is used to allow one process to execute while blocking all - other processes in the same process group. This is often useful when downloading is required - as we only want to download in one process to prevent file corruption. 
- Args: - executor_rank (int): the process rank to execute without blocking, all other processes will be blocked - parallel_mode (ParallelMode): the parallel mode corresponding to a process group - Usage: - with barrier_context(): - dataset = CIFAR10(root='./data', download=True) - """ - - def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL): - # the class name is lowercase by convention - current_rank = gpc.get_local_rank(parallel_mode=parallel_mode) - self.should_block = current_rank != executor_rank - self.group = gpc.get_group(parallel_mode=parallel_mode) - - def __enter__(self): - if self.should_block: - dist.barrier(group=self.group) - - def __exit__(self, exc_type, exc_value, exc_traceback): - if not self.should_block: - dist.barrier(group=self.group) diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt deleted file mode 100644 index c34df7992..000000000 --- a/examples/language/opt/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -colossalai -torch >= 1.8.1 -datasets >= 1.8.0 -sentencepiece != 0.1.92 -protobuf -accelerate == 0.13.2 diff --git a/examples/language/opt/run_clm.py b/examples/language/opt/run_clm.py deleted file mode 100755 index c6590323e..000000000 --- a/examples/language/opt/run_clm.py +++ /dev/null @@ -1,596 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) -on a text file or a dataset without using HuggingFace Trainer. - -Here is the full list of checkpoints on the hub that can be fine-tuned by this script: -https://huggingface.co/models?filter=text-generation -""" -# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. - -import math -import os -import time -from itertools import chain - -import datasets -import torch -import torch.distributed as dist -from accelerate.utils import set_seed -from context import barrier_context -from datasets import load_dataset -from packaging import version -from torch.utils.data import DataLoader -from tqdm.auto import tqdm - -import colossalai -import transformers -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader -from colossalai.utils.model.colo_init_context import ColoInitContext -from transformers import ( - CONFIG_MAPPING, - MODEL_MAPPING, - AutoConfig, - AutoTokenizer, - GPT2Tokenizer, - OPTForCausalLM, - SchedulerType, - default_data_collator, - get_scheduler, -) -from transformers.utils.versions import require_version - -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") - -MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) -MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) - - -def get_time_stamp(): - torch.cuda.synchronize() - return time.time() - - -def parse_args(): - parser = colossalai.get_default_parser() - parser.add_argument( - "--dataset_name", - type=str, 
- default=None, - help="The name of the dataset to use (via the datasets library).", - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The configuration name of the dataset to use (via the datasets library).", - ) - parser.add_argument("--train_file", - type=str, - default=None, - help="A csv or a json file containing the training data.") - parser.add_argument("--validation_file", - type=str, - default=None, - help="A csv or a json file containing the validation data.") - parser.add_argument( - "--validation_split_percentage", - default=5, - help="The percentage of the train set used as validation set in case there's no validation split", - ) - parser.add_argument( - "--model_name_or_path", - type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=True, - ) - parser.add_argument( - "--config_name", - type=str, - default=None, - help="Pretrained config name or path if not the same as model_name", - ) - parser.add_argument( - "--tokenizer_name", - type=str, - default=None, - help="Pretrained tokenizer name or path if not the same as model_name", - ) - parser.add_argument( - "--use_slow_tokenizer", - action="store_true", - help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", - ) - parser.add_argument( - "--per_device_train_batch_size", - type=int, - default=8, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument( - "--per_device_eval_batch_size", - type=int, - default=8, - help="Batch size (per device) for the evaluation dataloader.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") - 
parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides num_train_epochs.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--lr_scheduler_type", - type=SchedulerType, - default="linear", - help="The scheduler type to use.", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], - ) - parser.add_argument("--num_warmup_steps", - type=int, - default=0, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--model_type", - type=str, - default=None, - help="Model type to use if training from scratch.", - choices=MODEL_TYPES, - ) - parser.add_argument( - "--block_size", - type=int, - default=None, - help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" - " this size for training. 
Default to the model max input length for single sentence inputs (take into" - " account special tokens)."), - ) - parser.add_argument( - "--preprocessing_num_workers", - type=int, - default=None, - help="The number of processes to use for the preprocessing.", - ) - parser.add_argument("--overwrite_cache", - type=bool, - default=False, - help="Overwrite the cached training and evaluation sets") - parser.add_argument("--no_keep_linebreaks", - action="store_true", - help="Do not keep line breaks when using TXT files.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_model_id", - type=str, - help="The name of the repository to keep in sync with the local `output_dir`.") - parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--checkpointing_steps", - type=str, - default=None, - help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help="If the training should continue from a checkpoint folder.", - ) - parser.add_argument( - "--with_tracking", - action="store_true", - help="Whether to enable experiment trackers for logging.", - ) - parser.add_argument( - "--report_to", - type=str, - default="all", - help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 
- "Only applicable when `--with_tracking` is passed."), - ) - - parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") - parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") - args = parser.parse_args() - - # Sanity checks - if args.dataset_name is None and args.train_file is None and args.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if args.train_file is not None: - extension = args.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." - if args.validation_file is not None: - extension = args.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." - - if args.push_to_hub: - assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." - - return args - - -def colo_memory_cap(size_in_GB): - from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device - cuda_capacity = colo_device_memory_capacity(get_current_device()) - if size_in_GB * (1024**3) < cuda_capacity: - colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) - print("Using {} GB of GPU memory".format(size_in_GB)) - - -def main(): - args = parse_args() - disable_existing_loggers() - colossalai.launch_from_torch(config=dict()) - logger = get_dist_logger() - is_main_process = dist.get_rank() == 0 - - if is_main_process: - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - if args.mem_cap > 0: - colo_memory_cap(args.mem_cap) - - # If passed along, set the training seed now. 
- if args.seed is not None: - set_seed(args.seed) - logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}") - - # Handle the repository creation - with barrier_context(): - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called - # 'text' is found. You can easily tweak this behavior (see below). - # - # In distributed training, the load_dataset function guarantee that only one local process can concurrently - # download the dataset. - logger.info("Start preparing dataset", ranks=[0]) - if args.dataset_name is not None: - # Downloading and loading a dataset from the hub. 
- raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - args.dataset_name, - args.dataset_config_name, - split=f"train[:{args.validation_split_percentage}%]", - ) - raw_datasets["train"] = load_dataset( - args.dataset_name, - args.dataset_config_name, - split=f"train[{args.validation_split_percentage}%:]", - ) - else: - data_files = {} - dataset_args = {} - if args.train_file is not None: - data_files["train"] = args.train_file - if args.validation_file is not None: - data_files["validation"] = args.validation_file - extension = args.train_file.split(".")[-1] - if extension == "txt": - extension = "text" - dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks - raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) - # If no validation data is there, validation_split_percentage will be used to divide the dataset. - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - extension, - data_files=data_files, - split=f"train[:{args.validation_split_percentage}%]", - **dataset_args, - ) - raw_datasets["train"] = load_dataset( - extension, - data_files=data_files, - split=f"train[{args.validation_split_percentage}%:]", - **dataset_args, - ) - logger.info("Dataset is prepared", ranks=[0]) - - # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at - # https://huggingface.co/docs/datasets/loading_datasets.html. - - # Load pretrained model and tokenizer - # - # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently - # download model & vocab. 
- if args.config_name: - config = AutoConfig.from_pretrained(args.config_name) - elif args.model_name_or_path: - config = AutoConfig.from_pretrained(args.model_name_or_path) - else: - config = CONFIG_MAPPING[args.model_type]() - logger.warning("You are instantiating a new config instance from scratch.") - logger.info("Model config has been created", ranks=[0]) - - if args.model_name_or_path == 'facebook/opt-13b': - tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) - else: - print(f'load model from {args.model_name_or_path}') - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) - logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) - - if args.init_in_cpu: - init_dev = torch.device('cpu') - else: - init_dev = get_current_device() - - # build model - if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': - # currently, there has a bug in pretrained opt-13b - # we can not import it until huggingface fix it - logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev): - model = OPTForCausalLM(config) - else: - logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev): - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - local_files_only=False) - - # enable graident checkpointing - model.gradient_checkpointing_enable() - - PLACEMENT_POLICY = 'auto' - cai_version = colossalai.__version__ - logger.info(f'using Colossal-AI version {cai_version}') - if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) - elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): - from 
colossalai.gemini import ChunkManager, GeminiManager - pg = ProcessGroup() - chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) - gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) - model = ZeroDDP(model, gemini_manager) - - logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) - - # Preprocessing the datasets. - # First we tokenize all the texts. - column_names = raw_datasets["train"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - def tokenize_function(examples): - return tokenizer(examples[text_column_name]) - - with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): - tokenized_datasets = raw_datasets.map( - tokenize_function, - batched=True, - num_proc=args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not args.overwrite_cache, - desc="Running tokenizer on dataset", - ) - - if args.block_size is None: - block_size = tokenizer.model_max_length - if block_size > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --block_size xxx.") - block_size = 1024 - else: - if args.block_size > tokenizer.model_max_length: - logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") - block_size = min(args.block_size, tokenizer.model_max_length) - - # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. - def group_texts(examples): - # Concatenate all texts. 
- concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= block_size: - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. - result = { - k: [t[i:i + block_size] for i in range(0, total_length, block_size) - ] for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - - # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder - # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower - # to preprocess. - # - # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: - # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map - - with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA): - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=args.preprocessing_num_workers, - load_from_cache_file=not args.overwrite_cache, - desc=f"Grouping texts in chunks of {block_size}", - ) - - train_dataset = lm_datasets["train"] - eval_dataset = lm_datasets["validation"] - - # Log a few random samples from the training set: - # for index in random.sample(range(len(train_dataset)), 3): - # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - # DataLoaders creation: - train_dataloader = get_dataloader(train_dataset, - shuffle=True, - add_sampler=True, - collate_fn=default_data_collator, - batch_size=args.per_device_train_batch_size) - eval_dataloader = DataLoader(eval_dataset, - collate_fn=default_data_collator, - batch_size=args.per_device_eval_batch_size) - 
logger.info("Dataloaders have been created", ranks=[0]) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": args.weight_decay, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - - optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) - optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - lr_scheduler = get_scheduler( - name=args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=args.num_warmup_steps, - num_training_steps=args.max_train_steps, - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # Train! 
- total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) - - logger.info("***** Running training *****", ranks=[0]) - logger.info(f" Num examples = {len(train_dataset)}", ranks=[0]) - logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0]) - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0]) - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0]) - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0]) - logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0]) - - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process) - completed_steps = 0 - starting_epoch = 0 - global_step = 0 - - for epoch in range(starting_epoch, args.num_train_epochs): - - if completed_steps >= args.max_train_steps: - break - - model.train() - for step, batch in enumerate(train_dataloader): - batch = {k: v.cuda() for k, v in batch.items()} - outputs = model(**batch) - loss = outputs['loss'] - optimizer.backward(loss) - - if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - progress_bar.update(1) - completed_steps += 1 - - global_step += 1 - logger.info("Global step {} finished".format(global_step + 1), ranks=[0]) - - if completed_steps >= args.max_train_steps: - break - - model.eval() - losses = [] - for step, batch in enumerate(eval_dataloader): - with torch.no_grad(): - batch = {k: v.cuda() for k, v in batch.items()} - outputs = model(**batch) - - loss = outputs['loss'].unsqueeze(0) - losses.append(loss) - - losses = torch.cat(losses) - losses = losses[:len(eval_dataset)] - try: - eval_loss = torch.mean(losses) - perplexity = math.exp(eval_loss) - except OverflowError: - perplexity = float("inf") - - 
logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0]) - - if args.output_dir is not None: - model_state = model.state_dict() - if is_main_process: - torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) - dist.barrier() - # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) - # model.load_state_dict(load_state, strict=False) - - logger.info("Training finished", ranks=[0]) - - -if __name__ == "__main__": - main() diff --git a/examples/language/opt/run_clm.sh b/examples/language/opt/run_clm.sh deleted file mode 100644 index 858d3325a..000000000 --- a/examples/language/opt/run_clm.sh +++ /dev/null @@ -1,22 +0,0 @@ -set -x -export BS=${1:-16} -export MEMCAP=${2:-0} -export MODEL=${3:-"125m"} -export GPUNUM=${4:-1} - -# make directory for logs -mkdir -p ./logs - -export MODLE_PATH="facebook/opt-${MODEL}" - -# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 -torchrun \ - --nproc_per_node ${GPUNUM} \ - --master_port 19198 \ - run_clm.py \ - --dataset_name wikitext \ - --dataset_config_name wikitext-2-raw-v1 \ - --output_dir $PWD \ - --mem_cap ${MEMCAP} \ - --model_name_or_path ${MODLE_PATH} \ - --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh new file mode 100644 index 000000000..d9625723a --- /dev/null +++ b/examples/language/opt/run_gemini.sh @@ -0,0 +1,20 @@ +set -x +export BS=${BS:-16} +export MEMCAP=${MEMCAP:-0} +# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`.
For `175b`, request the pretrained weights from the OPT weight downloading page (https://github.com/facebookresearch/metaseq/tree/main/projects/OPT). +export MODEL=${MODEL:-"125m"} +export GPUNUM=${GPUNUM:-1} + +# make directory for logs +mkdir -p ./logs + +export MODLE_PATH="facebook/opt-${MODEL}" + +# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + train_gemini_opt.py \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE_PATH} \ + --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py new file mode 100755 index 000000000..64426ba42 --- /dev/null +++ b/examples/language/opt/train_gemini_opt.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import time
from functools import partial

import datasets
import torch
import torch.distributed as dist
import transformers
from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def get_data(batch_size, seq_len, vocab_size):
    """Build one batch of random token ids on the current CUDA device.

    Args:
        batch_size: number of sequences in the batch.
        seq_len: number of tokens per sequence.
        vocab_size: exclusive upper bound for the sampled token ids.

    Returns:
        A tuple ``(input_ids, attention_mask)``; both tensors have shape
        ``(batch_size, seq_len)`` and the mask is all ones.
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def get_time_stamp():
    """Return the wall-clock time after waiting for pending CUDA work to finish."""
    torch.cuda.synchronize()
    return time.time()


def get_tflops(model_numel, batch_size, seq_len, step_time):
    """Estimate the achieved TFLOPS of one training step.

    Args:
        model_numel: number of model parameters.
        batch_size: number of sequences processed in the step.
        seq_len: number of tokens per sequence.
        step_time: duration of the step in seconds.

    Returns:
        ``model_numel * batch_size * seq_len * 8 / 1e12 / step_time``.
    """
    # NOTE(review): the factor 8 presumably accounts for forward + backward
    # + recomputation under gradient checkpointing -- confirm.
    # The 1e-12 term guards against a zero step_time.
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def parse_args():
    """Parse the command-line arguments of the training script.

    The parser starts from ``colossalai.get_default_parser()`` and adds the
    model, optimization and memory options used by ``main``.

    Returns:
        The parsed argument namespace.
    """
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
    parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
    return parser.parse_args()


def colo_memory_cap(size_in_GB):
    """Limit this process's GPU memory to ``size_in_GB`` gigabytes.

    The cap is applied only when the requested size is smaller than the
    capacity reported for the current device; otherwise nothing changes.

    Args:
        size_in_GB: requested memory budget in GB.
    """
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

    requested_bytes = size_in_GB * (1024**3)
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if requested_bytes < cuda_capacity:
        colo_set_process_memory_fraction(requested_bytes / cuda_capacity)
        print(f"Using {size_in_GB} GB of GPU memory")


def main():
    """Train an OPT causal language model with Gemini on random token data.

    Launches ColossalAI from the torchrun environment, builds the model
    (from scratch or from a pretrained checkpoint), wraps it in ``GeminiDDP``
    with a ``GeminiAdamOptimizer`` and runs ``args.max_train_steps`` steps,
    logging the estimated TFLOPS of every step on rank 0.
    """
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch({})
    logger = get_dist_logger()
    is_main_process = dist.get_rank() == 0

    # Only the main process keeps verbose third-party logging.
    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # If passed along, set the training seed now.
    if args.seed is not None:
        # Fixed: the original called the non-existent `torch.mannul_seed`,
        # which raised AttributeError whenever --seed was given.
        torch.manual_seed(args.seed)
        logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")

    # Load the model configuration.
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
    logger.info("Model config has been created", ranks=[0])

    init_dev = torch.device('cpu') if args.init_in_cpu else get_current_device()

    # Build the model.
    if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
        # The pretrained opt-13b checkpoint currently has a bug, so it cannot
        # be loaded until HuggingFace fixes it; train from scratch instead.
        logger.info("Train a new model from scratch", ranks=[0])
        with ColoInitContext(device=init_dev, dtype=torch.half):
            model = OPTForCausalLM(config)
    else:
        logger.info("Finetune a pre-trained model", ranks=[0])
        with ColoInitContext(device=init_dev, dtype=torch.half):
            model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                   from_tf=bool(".ckpt" in args.model_name_or_path),
                                                   config=config,
                                                   local_files_only=False)

    # Enable gradient checkpointing.
    model.gradient_checkpointing_enable()

    # Count parameters before wrapping the model; used for the TFLOPS estimate.
    numel = sum(p.numel() for p in model.parameters())
    PLACEMENT_POLICY = 'cpu'
    model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
    optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

    model.train()
    for step in range(args.max_train_steps):
        # Synchronize before reading the clock so the measured interval
        # covers exactly this step's GPU work.
        st_time = get_time_stamp()
        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)

        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        # The Gemini optimizer drives the backward pass itself.
        optimizer.backward(loss)

        optimizer.step()
        optimizer.zero_grad()
        step_time = get_time_stamp() - st_time
        step_tflops = get_tflops_func(step_time)

        logger.info(f"step {step} finished, Tflops {step_tflops}", ranks=[0])

    logger.info("Training finished", ranks=[0])


if __name__ == "__main__":
    main()