mirror of https://github.com/hpcaitech/ColossalAI
[example] update opt example using booster api (#3918)
parent
9166988d9b
commit
e417dd004e
@ -19,15 +19,35 @@ Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/fa

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.

We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).

## Our Modifications

We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.

## Quick Start
You can launch training by using the following bash script

```bash
bash ./run_gemini.sh
```

We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. A minimal sketch of this pattern is included at the end of this README.

## Run Demo

By running the following script:

```bash
bash run_demo.sh
```
You will finetune a [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) model on this [dataset](https://huggingface.co/datasets/hugginglearners/netflix-shows), which contains more than 8000 comments on Netflix shows.

The script can be modified if you want to try another set of hyperparameters or change to another OPT model of a different size, as shown in the snippet below.

The demo code is adapted from this [blog](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475) and the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
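For instance, a hypothetical variant of the configuration block at the top of `run_demo.sh` (the full script appears later in this diff) might look like the following; the values here are purely illustrative, not recommendations:

```bash
# Hypothetical edits to the variables defined at the top of run_demo.sh.
MODEL="facebook/opt-1.3b"   # switch to a larger OPT checkpoint
PLUGIN="low_level_zero"     # one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
GPUNUM=8                    # number of gpus to use
BS=8                        # batch size per gpu
LR="2e-5"                   # learning rate
EPOCH=3                     # number of epochs
```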
## Run Benchmark

You can run a benchmark for the OPT model by running the following script:

```bash
bash run_benchmark.sh
```

The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing.
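For reference, the plugin selection and boosting pattern used by `opt_train_demo.py` and `opt_benchmark.py` (described under "Our Modifications" above) boils down to roughly the sketch below. It omits the dataloader, scheduler, and the ColoInitContext used for Gemini in the full scripts, so treat it as an illustration of the pattern rather than a replacement for them:

```python
# Minimal sketch of the Booster / plugin pattern used in this example
# (launch with `torchrun --standalone --nproc_per_node <N> <this_file>.py`).
import colossalai
from transformers import AutoConfig, OPTForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

colossalai.launch_from_torch(config={})

# Each plugin corresponds to one training strategy.
plugin_name = "gemini"  # or "torch_ddp", "torch_ddp_fp16", "low_level_zero"
if plugin_name.startswith("torch_ddp"):
    plugin = TorchDDPPlugin()
elif plugin_name == "gemini":
    plugin = GeminiPlugin(device=get_current_device(), placement_policy="cpu", initial_scale=2**5)
else:
    plugin = LowLevelZeroPlugin(initial_scale=2**5)

# The full scripts additionally build the model under ColoInitContext when the Gemini plugin is used.
model = OPTForCausalLM(AutoConfig.from_pretrained("facebook/opt-125m"))
optimizer = HybridAdam(model.parameters(), lr=5e-5)

# Booster wraps the model and optimizer according to the chosen plugin; the training
# loop then calls booster.backward(loss, optimizer) instead of loss.backward().
booster = Booster(plugin=plugin, mixed_precision="fp16" if plugin_name == "torch_ddp_fp16" else None)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```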
@ -0,0 +1,120 @@
from colossalai import get_default_parser


def parse_demo_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-350m",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output_model.bin",
        help="The path of your saved model after finetuning."
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=10,
        help="Number of epochs."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.1,
        help="Ratio of warmup steps against total training steps."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )

    args = parser.parse_args()
    return args


def parse_benchmark_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-125m",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--mem_cap",
        type=int,
        default=0,
        help="Limit on the usage of space for each GPU (in GB)."
    )
    args = parser.parse_args()

    return args
@ -1,21 +0,0 @@
export BS=16
export MEMCAP=0
export MODEL="6.7b"
export GPUNUM=1

for MODEL in "6.7b" "13b" "1.3b"
do
  for GPUNUM in 8 1
  do
    for BS in 16 24 32 8
    do
      for MEMCAP in 0 40
      do
        pkill -9 torchrun
        pkill -9 python

        env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh
      done
    done
  done
done
@ -0,0 +1,37 @@
import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class NetflixDataset(Dataset):

    def __init__(self, tokenizer):

        super().__init__()

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description']
        self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions])

        for txt in self.txt_list:
            encodings_dict = self.tokenizer('</s>' + txt + '</s>',
                                            truncation=True,
                                            max_length=self.max_length,
                                            padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]


def netflix_collator(data):
    return {'input_ids': torch.stack([x[0] for x in data]),
            'attention_mask': torch.stack([x[1] for x in data]),
            'labels': torch.stack([x[0] for x in data])}
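As a quick sanity check outside the training script, the dataset and collator above could be driven by a plain PyTorch `DataLoader`. The snippet below is a minimal sketch assuming the `facebook/opt-350m` tokenizer that the demo defaults to:

```python
# Minimal sketch: exercise NetflixDataset / netflix_collator without the Booster pipeline.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data import NetflixDataset, netflix_collator

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = NetflixDataset(tokenizer)  # downloads hugginglearners/netflix-shows on first use

# Each batch is a dict of input_ids / attention_mask / labels tensors,
# directly consumable by OPTForCausalLM(**batch).
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=netflix_collator)
batch = next(iter(loader))
print({k: v.shape for k, v in batch.items()})
```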
@ -0,0 +1,146 @@
import time

import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_benchmark_args

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def format_num(num: int, bytes=False):
    """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor


def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")


def main():

    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Whether to set limit of memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model
    # Initialize the model under ColoInitContext if using GeminiPlugin
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu',
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        model = OPTForCausalLM(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    # Start training.
    logger.info(f"Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    for _ in range(args.max_train_steps):

        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        booster.backward(loss, optimizer)
        optimizer.step()

        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute Statistics
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(f"Testing finished, "
                f"batch size per gpu: {args.batch_size}, "
                f"plugin: {args.plugin}, "
                f"throughput: {throughput}, "
                f"maximum memory usage per gpu: {max_mem}.",
                ranks=[0])


if __name__ == "__main__":
    main()
@ -0,0 +1,149 @@
import time

import torch
import datasets
import transformers
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
from transformers import get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
from tqdm import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_demo_args
from data import NetflixDataset, netflix_collator

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):

    torch.cuda.synchronize()
    model.train()

    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:

        for batch in pbar:

            # Forward
            optimizer.zero_grad()
            batch = move_to_cuda(batch, torch.cuda.current_device())

            outputs = model(use_cache=False, **batch)
            loss = outputs['loss']

            # Backward
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()

            # Print batch loss
            pbar.set_postfix({'loss': loss.item()})


def main():

    args = parse_demo_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build OPT model
    # Initialize the model under ColoInitContext if using GeminiPlugin
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu',
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        model = OPTForCausalLM(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=netflix_collator)

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set lr scheduler
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=len(dataloader) * args.num_epoch
    )

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)

    # Start finetuning
    logger.info(f"Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)

    # Finish training and evaluate
    logger.info(f"Finish finetuning", ranks=[0])
    booster.save_model(model, args.output_path)
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])


if __name__ == "__main__":
    main()
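Once the demo finishes, the fine-tuned weights live at `--output_path` (default `./output_model.bin`). One possible way to reload them for generation, assuming the checkpoint written by `booster.save_model` in this example is a plain PyTorch state dict, is sketched below:

```python
# Hypothetical reload of the checkpoint written by booster.save_model();
# assumes ./output_model.bin holds a plain state dict for facebook/opt-350m.
import torch
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-350m")
model = OPTForCausalLM(config)
model.load_state_dict(torch.load("./output_model.bin", map_location="cpu"))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
inputs = tokenizer("A heartwarming series about", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```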
@ -1,2 +1,4 @@
colossalai >= 0.1.12
torch >= 1.8.1
datasets >= 1.8.0
transformers >= 4.20.0
@ -0,0 +1,30 @@
set -xe
pip install -r requirements.txt

export BS=32
export MEMCAP=0
export GPUNUM=1

# acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`
export MODEL="125m"

for BS in 8 32 128
do
  for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
  do
    for GPUNUM in 1 4
    do

      MODEL_PATH="facebook/opt-${MODEL}"
      torchrun \
        --standalone \
        --nproc_per_node ${GPUNUM} \
        opt_benchmark.py \
        --model_name_or_path ${MODEL_PATH} \
        --mem_cap ${MEMCAP} \
        --plugin ${PLUGIN} \
        --batch_size ${BS}

    done
  done
done
@ -0,0 +1,44 @@
set -xe
pip install -r requirements.txt

# model name or path
MODEL="facebook/opt-350m"

# path for saving model
OUTPUT_PATH="./output_model.bin"

# plugin (training strategy)
# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
PLUGIN="gemini"

# number of gpus to use
GPUNUM=4

# batch size per gpu
BS=16

# learning rate
LR="5e-5"

# number of epochs
EPOCH=10

# weight decay
WEIGHT_DECAY=0.01

# ratio of warmup steps
WARMUP_RATIO=0.1

# run the script for demo
torchrun \
  --standalone \
  --nproc_per_node ${GPUNUM} \
  opt_train_demo.py \
  --model_name_or_path ${MODEL} \
  --output_path ${OUTPUT_PATH} \
  --plugin ${PLUGIN} \
  --batch_size ${BS} \
  --num_epoch ${EPOCH} \
  --learning_rate ${LR} \
  --weight_decay ${WEIGHT_DECAY} \
  --warmup_ratio ${WARMUP_RATIO}
@ -1,28 +0,0 @@
set -x
export BS=${BS:-16}
export MEMCAP=${MEMCAP:-0}
# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
export MODEL=${MODEL:-"125m"}
export GPUNUM=${GPUNUM:-1}
export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}

# make directory for logs
mkdir -p ./logs

if [ ${USE_SHARD_INIT} = "true" ]; then
  USE_SHARD_INIT="--shardinit"
else
  USE_SHARD_INIT=""
fi

export MODLE_PATH="facebook/opt-${MODEL}"

# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
torchrun \
  --nproc_per_node ${GPUNUM} \
  --master_port 19198 \
  train_gemini_opt.py \
  --mem_cap ${MEMCAP} \
  --model_name_or_path ${MODLE_PATH} \
  ${USE_SHARD_INIT} \
  --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
@ -1,4 +1,19 @@
for GPUNUM in 2 1
do
  env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh
done

set -xe
pip install -r requirements.txt

BS=4
for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
do
  for GPUNUM in 1 4
  do

    torchrun \
      --standalone \
      --nproc_per_node ${GPUNUM} \
      opt_benchmark.py \
      --model_name_or_path "facebook/opt-125m" \
      --plugin ${PLUGIN} \
      --batch_size ${BS}

  done
done
@ -1,233 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
on a text file or a dataset without using HuggingFace Trainer.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import time
from functools import partial

import datasets
import torch
import torch.distributed as dist
import transformers
from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP


def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def get_time_stamp():
    torch.cuda.synchronize()
    return time.time()


def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--shardinit",
        action="store_true",
        help="Initialize the model with tensor parallel",
    )
    parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
    parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
    args = parser.parse_args()

    return args


def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print("Using {} GB of GPU memory".format(size_in_GB))


def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch({})
    logger = get_dist_logger()
    is_main_process = dist.get_rank() == 0

    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # If passed along, set the training seed now.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
    logger.info("Model config has been created", ranks=[0])

    if args.init_in_cpu:
        init_dev = torch.device('cpu')
    else:
        init_dev = get_current_device()

    # shard init parameters
    if args.shardinit:
        logger.info("Sharding initialization !", ranks=[0])
    else:
        logger.info("Skipping sharding initialization", ranks=[0])

    world_size = torch.distributed.get_world_size()
    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

    # build model
    if args.model_name_or_path is None:
        logger.info("Train a new model from scratch", ranks=[0])
        with ColoInitContext(device=init_dev,
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        logger.info("Finetune a pre-trained model", ranks=[0])
        with ColoInitContext(device=init_dev,
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                   from_tf=bool(".ckpt" in args.model_name_or_path),
                                                   config=config,
                                                   local_files_only=False)

    # enable gradient checkpointing
    model.gradient_checkpointing_enable()

    numel = sum([p.numel() for p in model.parameters()])
    PLACEMENT_POLICY = 'cpu'
    model = GeminiDDP(model,
                      device=get_current_device(),
                      placement_policy=PLACEMENT_POLICY,
                      pin_memory=True,
                      strict_ddp_mode=args.shardinit)
    optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

    model.train()
    for step in range(args.max_train_steps):
        st_time = time.time()
        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)

        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        optimizer.backward(loss)

        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()
        step_time = time.time() - st_time
        step_tflops = get_tflops_func(step_time)

        logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0])

    logger.info("Training finished", ranks=[0])


if __name__ == "__main__":
    main()