
[example] update opt example using booster api (#3918)

Branch: pull/3929/head
Author: Baizhou Zhang (committed by GitHub)
Commit: e417dd004e
Changed files:
1. examples/language/opt/README.md (32)
2. examples/language/opt/args.py (120)
3. examples/language/opt/benchmark.sh (21)
4. examples/language/opt/data.py (37)
5. examples/language/opt/opt_benchmark.py (146)
6. examples/language/opt/opt_train_demo.py (149)
7. examples/language/opt/requirements.txt (2)
8. examples/language/opt/run_benchmark.sh (30)
9. examples/language/opt/run_demo.sh (44)
10. examples/language/opt/run_gemini.sh (28)
11. examples/language/opt/test_ci.sh (19)
12. examples/language/opt/train_gemini_opt.py (233)

examples/language/opt/README.md (32 changed lines)

@@ -19,15 +19,35 @@ Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/fa
 The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
-We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
-the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
 ## Our Modifications
-We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
-## Quick Start
-You can launch training by using the following bash script
+We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
+the tokenization).
+We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+## Run Demo
+By running the following script:
 ```bash
-bash ./run_gemini.sh
+bash run_demo.sh
 ```
+You will finetune a [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) model on this [dataset](https://huggingface.co/datasets/hugginglearners/netflix-shows), which contains more than 8000 comments on Netflix shows.
+The script can be modified if you want to try another set of hyperparameters or switch to another OPT model of a different size.
+The demo code is adapted from this [blog](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475) and the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
+## Run Benchmark
+You can run a benchmark for the OPT model by running the following script:
+```bash
+bash run_benchmark.sh
+```
+The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also adjust this script to configure your own set of hyperparameters for testing.
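For readers new to the Booster API, the sketch below shows the minimal pattern the new scripts follow; the plugin arguments mirror those in `opt_benchmark.py` and `opt_train_demo.py`, so treat the exact values as illustrative rather than prescribed defaults.
```python
# Minimal Booster workflow, mirroring opt_benchmark.py / opt_train_demo.py.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin    # or TorchDDPPlugin / LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from transformers import OPTForCausalLM

colossalai.launch_from_torch(config={}, seed=42)

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
optimizer = HybridAdam(model.parameters(), lr=5e-5)

plugin = GeminiPlugin(device=get_current_device(), placement_policy='cpu',
                      pin_memory=True, strict_ddp_mode=True, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# Inside the training loop the booster takes over the backward pass:
#   loss = model(**batch, use_cache=False)['loss']
#   booster.backward(loss, optimizer)
#   optimizer.step(); optimizer.zero_grad()
```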

examples/language/opt/args.py (120 changed lines)

@@ -0,0 +1,120 @@
from colossalai import get_default_parser


def parse_demo_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-350m",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output_model.bin",
        help="The path of your saved model after finetuning."
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=10,
        help="Number of epochs."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.1,
        help="Ratio of warmup steps against total training steps."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )

    args = parser.parse_args()
    return args


def parse_benchmark_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-125m",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--mem_cap",
        type=int,
        default=0,
        help="Limit on the usage of space for each GPU (in GB)."
    )

    args = parser.parse_args()
    return args
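The `--plugin` flag above is just a string. A hypothetical helper like the one below (not part of the commit) shows how it maps onto the Booster plugins constructed in `opt_benchmark.py` and `opt_train_demo.py`; the fp16 variant is handled through the Booster's `mixed_precision` argument rather than a separate plugin.
```python
# Hypothetical helper: map the --plugin string from these parsers to a Booster plugin.
# The branching mirrors opt_benchmark.py / opt_train_demo.py; 'torch_ddp_fp16' additionally
# sets mixed_precision='fp16' on the Booster rather than on the plugin itself.
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin


def build_plugin(name: str):
    if name.startswith("torch_ddp"):          # covers 'torch_ddp' and 'torch_ddp_fp16'
        return TorchDDPPlugin()
    if name == "gemini":
        return GeminiPlugin(placement_policy='cpu', initial_scale=2**5)
    if name == "low_level_zero":
        return LowLevelZeroPlugin(initial_scale=2**5)
    raise ValueError(f"Unknown plugin: {name}")
```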

examples/language/opt/benchmark.sh (21 changed lines)

@@ -1,21 +0,0 @@
export BS=16
export MEMCAP=0
export MODEL="6.7b"
export GPUNUM=1

for MODEL in "6.7b" "13b" "1.3b"
do
  for GPUNUM in 8 1
  do
    for BS in 16 24 32 8
    do
      for MEMCAP in 0 40
      do
        pkill -9 torchrun
        pkill -9 python
        env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh
      done
    done
  done
done

examples/language/opt/data.py (37 changed lines)

@@ -0,0 +1,37 @@
import torch
from torch.utils.data import Dataset

from datasets import load_dataset


class NetflixDataset(Dataset):

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description']
        self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions])
        for txt in self.txt_list:
            encodings_dict = self.tokenizer('</s>' + txt + '</s>',
                                            truncation=True,
                                            max_length=self.max_length,
                                            padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]


def netflix_collator(data):
    return {'input_ids': torch.stack([x[0] for x in data]),
            'attention_mask': torch.stack([x[1] for x in data]),
            'labels': torch.stack([x[0] for x in data])}
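As a quick check, the dataset and collator can also be exercised with a plain PyTorch `DataLoader` outside the demo script; this is only a sketch, since the demo itself builds its dataloader through `plugin.prepare_dataloader`.
```python
# Standalone usage sketch for NetflixDataset / netflix_collator.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data import NetflixDataset, netflix_collator

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = NetflixDataset(tokenizer)      # downloads hugginglearners/netflix-shows on first use
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=netflix_collator)

batch = next(iter(loader))               # dict with 'input_ids', 'attention_mask', 'labels'
print(batch['input_ids'].shape)
```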

examples/language/opt/opt_benchmark.py (146 changed lines)

@@ -0,0 +1,146 @@
import time

import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_benchmark_args

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def format_num(num: int, bytes=False):
    """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor


def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")


def main():

    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Whether to set limit of memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model
    # Initialize the model under ColoInitContext if using GeminiPlugin
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu',
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        model = OPTForCausalLM(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    # Start training.
    logger.info(f"Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    for _ in range(args.max_train_steps):
        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        booster.backward(loss, optimizer)
        optimizer.step()
        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute Statistics
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(f"Testing finished, "
                f"batch size per gpu: {args.batch_size}, "
                f"plugin: {args.plugin}, "
                f"throughput: {throughput}, "
                f"maximum memory usage per gpu: {max_mem}.",
                ranks=[0])


if __name__ == "__main__":
    main()
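The `--mem_cap` option relies on ColossalAI's `colo_set_process_memory_fraction`. For reference, a rough plain-PyTorch analogue is sketched below; it is an assumption for comparison, not something the example uses.
```python
# Rough plain-PyTorch analogue of colo_memory_cap (assumption, not used by the example).
import torch


def memory_cap_gb(size_in_gb: int, device: int = 0) -> None:
    total = torch.cuda.get_device_properties(device).total_memory
    cap = size_in_gb * (1024 ** 3)
    if cap < total:
        # Restrict the CUDA caching allocator to roughly size_in_gb GB on this device.
        torch.cuda.set_per_process_memory_fraction(cap / total, device)
```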

examples/language/opt/opt_train_demo.py (149 changed lines)

@@ -0,0 +1,149 @@
import time

import torch
import datasets
import transformers
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
from tqdm import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_demo_args
from data import NetflixDataset, netflix_collator

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):

    torch.cuda.synchronize()
    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:

        for batch in pbar:

            # Forward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())
            outputs = model(use_cache=False, **batch)
            loss = outputs['loss']

            # Backward
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()

            # Print batch loss
            pbar.set_postfix({'loss': loss.item()})


def main():

    args = parse_demo_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build OPT model
    # Initialize the model under ColoInitContext if using GeminiPlugin
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu',
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        model = OPTForCausalLM(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=netflix_collator)

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set lr scheduler
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=len(dataloader) * args.num_epoch
    )

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)

    # Start finetuning
    logger.info(f"Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)

    # Finish training and evaluate
    logger.info(f"Finish finetuning", ranks=[0])
    booster.save_model(model, args.output_path)
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])


if __name__ == "__main__":
    main()
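Once the demo finishes, the checkpoint written by `booster.save_model` can be reused for generation. The sketch below assumes the output file is a plain PyTorch `state_dict` (which is what saving to a single `.bin` path suggests here); adjust the loading step if your plugin writes a different format.
```python
# Sketch: generate a Netflix-style description with the finetuned checkpoint.
# Assumes ./output_model.bin holds a plain state_dict saved by booster.save_model.
import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
model.load_state_dict(torch.load("./output_model.bin", map_location="cpu"))
model.eval()

inputs = tokenizer("</s>", return_tensors="pt")
output = model.generate(**inputs, max_length=64, do_sample=True, top_p=0.95)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```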

examples/language/opt/requirements.txt (2 changed lines)

@@ -1,2 +1,4 @@
colossalai >= 0.1.12
torch >= 1.8.1
datasets >= 1.8.0
transformers >= 4.20.0

examples/language/opt/run_benchmark.sh (30 changed lines)

@@ -0,0 +1,30 @@
set -xe
pip install -r requirements.txt

export BS=32
export MEMCAP=0
export GPUNUM=1

# acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`
export MODEL="125m"

for BS in 8 32 128
do
  for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
  do
    for GPUNUM in 1 4
    do
      MODEL_PATH="facebook/opt-${MODEL}"
      torchrun \
        --standalone \
        --nproc_per_node ${GPUNUM} \
        opt_benchmark.py \
        --model_name_or_path ${MODEL_PATH} \
        --mem_cap ${MEMCAP} \
        --plugin ${PLUGIN} \
        --batch_size ${BS}
    done
  done
done
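For completeness, the same sweep can also be driven from Python (a sketch, assuming `torchrun` is on `PATH`); it simply replays the loops above.
```python
# Python sketch of the run_benchmark.sh sweep (assumes torchrun is available on PATH).
import itertools
import subprocess

for bs, plugin, gpus in itertools.product(
        [8, 32, 128],
        ["torch_ddp", "torch_ddp_fp16", "low_level_zero", "gemini"],
        [1, 4]):
    subprocess.run([
        "torchrun", "--standalone", f"--nproc_per_node={gpus}",
        "opt_benchmark.py",
        "--model_name_or_path", "facebook/opt-125m",
        "--plugin", plugin,
        "--batch_size", str(bs),
    ], check=True)
```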

examples/language/opt/run_demo.sh (44 changed lines)

@@ -0,0 +1,44 @@
set -xe
pip install -r requirements.txt

# model name or path
MODEL="facebook/opt-350m"

# path for saving model
OUTPUT_PATH="./output_model.bin"

# plugin (training strategy)
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
PLUGIN="gemini"

# number of gpus to use
GPUNUM=4

# batch size per gpu
BS=16

# learning rate
LR="5e-5"

# number of epochs
EPOCH=10

# weight decay
WEIGHT_DECAY=0.01

# ratio of warmup steps
WARMUP_RATIO=0.1

# run the script for demo
torchrun \
  --standalone \
  --nproc_per_node ${GPUNUM} \
  opt_train_demo.py \
  --model_name_or_path ${MODEL} \
  --output_path ${OUTPUT_PATH} \
  --plugin ${PLUGIN} \
  --batch_size ${BS} \
  --num_epoch ${EPOCH} \
  --learning_rate ${LR} \
  --weight_decay ${WEIGHT_DECAY} \
  --warmup_ratio ${WARMUP_RATIO}

examples/language/opt/run_gemini.sh (28 changed lines)

@@ -1,28 +0,0 @@
set -x

export BS=${BS:-16}
export MEMCAP=${MEMCAP:-0}
# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
export MODEL=${MODEL:-"125m"}
export GPUNUM=${GPUNUM:-1}
export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}

# make directory for logs
mkdir -p ./logs

if [ ${USE_SHARD_INIT} = "true" ]; then
  USE_SHARD_INIT="--shardinit"
else
  USE_SHARD_INIT=""
fi

export MODLE_PATH="facebook/opt-${MODEL}"

# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
torchrun \
  --nproc_per_node ${GPUNUM} \
  --master_port 19198 \
  train_gemini_opt.py \
  --mem_cap ${MEMCAP} \
  --model_name_or_path ${MODLE_PATH} \
  ${USE_SHARD_INIT} \
  --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log

examples/language/opt/test_ci.sh (19 changed lines)

@@ -1,4 +1,19 @@
-for GPUNUM in 2 1
+set -xe
+pip install -r requirements.txt
+BS=4
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
 do
-env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh
+for GPUNUM in 1 4
+do
+torchrun \
+  --standalone \
+  --nproc_per_node ${GPUNUM} \
+  opt_benchmark.py \
+  --model_name_or_path "facebook/opt-125m" \
+  --plugin ${PLUGIN} \
+  --batch_size ${BS}
+done
 done

examples/language/opt/train_gemini_opt.py (233 changed lines)

@@ -1,233 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
on a text file or a dataset without using HuggingFace Trainer.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import time
from functools import partial

import datasets
import torch
import torch.distributed as dist
import transformers
from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP


def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def get_time_stamp():
    torch.cuda.synchronize()
    return time.time()


def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--shardinit",
        action="store_true",
        help="Initialize the model with tensor parallel",
    )
    parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
    parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
    args = parser.parse_args()

    return args


def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print("Using {} GB of GPU memory".format(size_in_GB))


def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch({})
    logger = get_dist_logger()
    is_main_process = dist.get_rank() == 0

    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # If passed along, set the training seed now.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
    logger.info("Model config has been created", ranks=[0])

    if args.init_in_cpu:
        init_dev = torch.device('cpu')
    else:
        init_dev = get_current_device()

    # shard init parameters
    if args.shardinit:
        logger.info("Sharding initialization !", ranks=[0])
    else:
        logger.info("Skipping sharding initialization", ranks=[0])

    world_size = torch.distributed.get_world_size()
    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

    # build model
    if args.model_name_or_path is None:
        logger.info("Train a new model from scratch", ranks=[0])
        with ColoInitContext(device=init_dev,
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        logger.info("Finetune a pre-trained model", ranks=[0])
        with ColoInitContext(device=init_dev,
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                   from_tf=bool(".ckpt" in args.model_name_or_path),
                                                   config=config,
                                                   local_files_only=False)

    # enable gradient checkpointing
    model.gradient_checkpointing_enable()

    numel = sum([p.numel() for p in model.parameters()])
    PLACEMENT_POLICY = 'cpu'
    model = GeminiDDP(model,
                      device=get_current_device(),
                      placement_policy=PLACEMENT_POLICY,
                      pin_memory=True,
                      strict_ddp_mode=args.shardinit)
    optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

    model.train()
    for step in range(args.max_train_steps):
        st_time = time.time()
        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)

        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        optimizer.backward(loss)

        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()
        step_time = time.time() - st_time
        step_tflops = get_tflops_func(step_time)

        logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0])

    logger.info("Training finished", ranks=[0])


if __name__ == "__main__":
    main()
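For reference, the `get_tflops` estimate in the removed script works out as follows; the numbers below are illustrative assumptions, not measurements from this commit.
```python
# Worked example of the removed script's get_tflops estimate:
# TFLOPS ≈ numel * batch_size * seq_len * 8 / 1e12 / step_time
numel = 125e6                     # assume a 125M-parameter OPT model
batch_size, seq_len = 8, 1024
step_time = 1.5                   # assume 1.5 s per training step
tflops = numel * batch_size * seq_len * 8 / 1e12 / step_time
print(f"{tflops:.2f} TFLOPS")     # ≈ 5.46
```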