mirror of https://github.com/hpcaitech/ColossalAI
Add handson to ColossalAI. (#1896)
Co-authored-by: Boxiang Wang <boxiang.wang1@gmail.com>
pull/1898/head
parent 0ef8154bfa
commit d9bf83e084

@ -0,0 +1,27 @@
# Handson 1: Multi-dimensional Parallelism with Colossal-AI

## Install Colossal-AI and other dependencies

```bash
sh install.sh
```

## Prepare Dataset

We use the CIFAR-10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use a customized directory for the dataset, you can set the environment variable `DATA` via the following command.

```bash
export DATA=/path/to/data
```

## Run on 2*2 device mesh

The current configuration in `config.py` sets TP=2 and PP=2.

```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```
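As a quick sanity check of this layout (an illustrative sketch, not a file in the repo): with 4 processes and the TP=2, PP=2 setting above, the remaining data parallel size is `world_size // (tp * pp)`.

```python
# assumed values matching config.py and the launch command above
world_size, tp, pp = 4, 2, 2
dp = world_size // (tp * pp)
assert world_size == dp * tp * pp
print(f"data parallel size = {dp}")  # prints 1
```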
@ -0,0 +1,36 @@
from colossalai.amp import AMP_TYPE

# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token

# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'

parallel = dict(
    pipeline=2,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0

# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']
@ -0,0 +1,4 @@
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org
pip install titans
colossalai check -i
@ -0,0 +1,116 @@
|
|||
import os
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
use_pipeline = is_using_pp()
|
||||
|
||||
# create model
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
||||
patch_size=gpc.config.PATCH_SIZE,
|
||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
||||
depth=gpc.config.DEPTH,
|
||||
num_heads=gpc.config.NUM_HEADS,
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
num_classes=10,
|
||||
init_method='jax',
|
||||
checkpoint=gpc.config.CHECKPOINT)
|
||||
|
||||
if use_pipeline:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(
|
||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
else:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
|
||||
# count number of parameters
|
||||
total_numel = 0
|
||||
for p in model.parameters():
|
||||
total_numel += p.numel()
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_stage = 0
|
||||
else:
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(
|
||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data/cifar10')
|
||||
train_dataloader, test_dataloader = build_cifar(
|
||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(
|
||||
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
# training
|
||||
engine.train()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
for _ in progress:
|
||||
engine.zero_grad()
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,20 @@
# Handson 2: Sequence Parallelism with BERT

## Prepare Dataset

We use the CIFAR-10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use a customized directory for the dataset, you can set the environment variable `DATA` via the following command.

```bash
export DATA=/path/to/data
```

## Run on 2*2 device mesh

The current configuration in `config.py` uses sequence parallelism with a parallel size of 4.

```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```
@ -0,0 +1,35 @@
from colossalai.amp import AMP_TYPE

# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token

# parallel setting
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = '1d'

parallel = dict(
    tensor=dict(size=4, mode='sequence')
)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0

# pipeline config
# pipeline parallelism is not enabled in this config, so a single micro-batch is used
NUM_MICRO_BATCHES = 1
@ -0,0 +1,116 @@
|
|||
import os
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
use_pipeline = is_using_pp()
|
||||
|
||||
# create model
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
||||
patch_size=gpc.config.PATCH_SIZE,
|
||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
||||
depth=gpc.config.DEPTH,
|
||||
num_heads=gpc.config.NUM_HEADS,
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
num_classes=10,
|
||||
init_method='jax',
|
||||
checkpoint=gpc.config.CHECKPOINT)
|
||||
|
||||
if use_pipeline:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(
|
||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
else:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
|
||||
# count number of parameters
|
||||
total_numel = 0
|
||||
for p in model.parameters():
|
||||
total_numel += p.numel()
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_stage = 0
|
||||
else:
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(
|
||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data/cifar10')
|
||||
train_dataloader, test_dataloader = build_cifar(
|
||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(
|
||||
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
# training
|
||||
engine.train()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
for _ in progress:
|
||||
engine.zero_grad()
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,4 +1,4 @@
-# Train ResNet on CIFAR10 with auto_parallel
+# Handson 3: Auto-Parallelism with ResNet

## Prepare Dataset

@ -0,0 +1,17 @@
# Handson 4: Comparison of Large Batch Training Optimization

## Prepare Dataset

We use the CIFAR-10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use a customized directory for the dataset, you can set the environment variable `DATA` via the following command.

```bash
export DATA=/path/to/data
```

## Run on 2*2 device mesh

```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```
@ -0,0 +1,36 @@
from colossalai.amp import AMP_TYPE

# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3

# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1    # add 1 for cls token

# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'

parallel = dict(
    pipeline=2,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0

# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']
@ -0,0 +1,117 @@
|
|||
import os
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import Lars, Lamb
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
use_pipeline = is_using_pp()
|
||||
|
||||
# create model
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
||||
patch_size=gpc.config.PATCH_SIZE,
|
||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
||||
depth=gpc.config.DEPTH,
|
||||
num_heads=gpc.config.NUM_HEADS,
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
num_classes=10,
|
||||
init_method='jax',
|
||||
checkpoint=gpc.config.CHECKPOINT)
|
||||
|
||||
if use_pipeline:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(
|
||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
else:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
|
||||
# count number of parameters
|
||||
total_numel = 0
|
||||
for p in model.parameters():
|
||||
total_numel += p.numel()
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_stage = 0
|
||||
else:
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(
|
||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data/cifar10')
|
||||
train_dataloader, test_dataloader = build_cifar(
|
||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE,
|
||||
weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
# training
|
||||
engine.train()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
for _ in progress:
|
||||
engine.zero_grad()
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
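The training script above instantiates `Lars` and also imports `Lamb` from `colossalai.nn.optimizer`, so the large-batch optimizers can be compared by swapping the constructor. A minimal sketch of the swap, using a stand-in model (the real script passes the ViT parameters and the values from `config.py`):

```python
import torch.nn as nn
from colossalai.nn.optimizer import Lamb, Lars

model = nn.Linear(512, 10)    # stand-in for the ViT built in train.py
lars = Lars(model.parameters(), lr=3e-3, weight_decay=0.3)
lamb = Lamb(model.parameters(), lr=3e-3, weight_decay=0.3)
```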
@ -0,0 +1 @@
# Handson 5: Fine-tuning and Serving for OPT from Hugging Face
@ -0,0 +1,77 @@
# Overview

This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI.

It supports tensor parallelism, batching and caching.

# How to run

Run OPT-125M:
```shell
python opt_fastapi.py opt-125m
```

It will launch an HTTP server on `0.0.0.0:7070` by default, and you can customize the host and port. You can open `localhost:7070/docs` in your browser to see the OpenAPI docs.
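As a quick check, you can also hit the `/generation` endpoint directly. A minimal client sketch, assuming the default host/port, the request fields defined in `opt_fastapi.py`, and that the `requests` package is installed:

```python
import requests

payload = {
    'prompt': 'Question: What is the longest river on the earth? Answer:',
    'max_tokens': 64,    # must be in (0, 256]
    'top_k': 50,         # optional sampling parameters
    'top_p': 0.5,
    'temperature': 0.7,
}
resp = requests.post('http://localhost:7070/generation', json=payload)
# a 406 status code means the request queue was full and the request was dropped
resp.raise_for_status()
print(resp.json()['text'])
```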
## Configure

### Configure model
```shell
python opt_fastapi.py <model>
```
Available models: opt-125m, opt-6.7b, opt-30b, opt-175b.

### Configure tensor parallelism
```shell
python opt_fastapi.py <model> --tp <TensorParallelismWorldSize>
```
The `<TensorParallelismWorldSize>` can be an integer in `[1, #GPUs]`. Default `1`.

### Configure checkpoint
```shell
python opt_fastapi.py <model> --checkpoint <CheckpointPath>
```
The `<CheckpointPath>` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded.

### Configure queue
```shell
python opt_fastapi.py <model> --queue_size <QueueSize>
```
The `<QueueSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, incoming requests will be dropped when the request queue is full (the HTTP status code of the response will be 406).

### Configure batching
```shell
python opt_fastapi.py <model> --max_batch_size <MaxBatchSize>
```
The `<MaxBatchSize>` can be an integer in `[1, MAXINT]`. The engine will form batches whose size is less than or equal to this value.

Note that the batch size is not always equal to `<MaxBatchSize>`, as some consecutive requests may not be batched.

### Configure caching
```shell
python opt_fastapi.py <model> --cache_size <CacheSize> --cache_list_size <CacheListSize>
```
This will cache `<CacheSize>` unique requests, and for each unique request it caches `<CacheListSize>` different results. A random result will be returned if the cache is hit.

The `<CacheSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the cache won't be applied. The `<CacheListSize>` can be an integer in `[1, MAXINT]`.

### Other configurations
```shell
python opt_fastapi.py -h
```

# How to benchmark
```shell
cd benchmark
locust
```

Then open the web interface link printed on your console.

# Pre-process pre-trained weights

## OPT-66B
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).

## OPT-175B
See [script/process-opt-175b](./script/process-opt-175b/).
@ -0,0 +1,59 @@
|
|||
import torch
|
||||
from typing import List, Deque, Tuple, Hashable, Any
|
||||
from energonai import BatchManager, SubmitEntry, TaskEntry
|
||||
|
||||
|
||||
class BatchManagerForGeneration(BatchManager):
|
||||
def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.max_batch_size = max_batch_size
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
def _left_padding(self, batch_inputs):
|
||||
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
|
||||
outputs = {'input_ids': [], 'attention_mask': []}
|
||||
for inputs in batch_inputs:
|
||||
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
|
||||
padding_len = max_len - len(input_ids)
|
||||
input_ids = [self.pad_token_id] * padding_len + input_ids
|
||||
attention_mask = [0] * padding_len + attention_mask
|
||||
outputs['input_ids'].append(input_ids)
|
||||
outputs['attention_mask'].append(attention_mask)
|
||||
for k in outputs:
|
||||
outputs[k] = torch.tensor(outputs[k])
|
||||
return outputs, max_len
|
||||
|
||||
@staticmethod
|
||||
def _make_batch_key(entry: SubmitEntry) -> tuple:
|
||||
data = entry.data
|
||||
return (data['top_k'], data['top_p'], data['temperature'])
|
||||
|
||||
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
|
||||
entry = q.popleft()
|
||||
uids = [entry.uid]
|
||||
batch = [entry.data]
|
||||
while len(batch) < self.max_batch_size:
|
||||
if len(q) == 0:
|
||||
break
|
||||
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
|
||||
break
|
||||
if q[0].data['max_tokens'] > entry.data['max_tokens']:
|
||||
break
|
||||
e = q.popleft()
|
||||
batch.append(e.data)
|
||||
uids.append(e.uid)
|
||||
inputs, max_len = self._left_padding(batch)
|
||||
trunc_lens = []
|
||||
for data in batch:
|
||||
trunc_lens.append(max_len + data['max_tokens'])
|
||||
inputs['top_k'] = entry.data['top_k']
|
||||
inputs['top_p'] = entry.data['top_p']
|
||||
inputs['temperature'] = entry.data['temperature']
|
||||
inputs['max_tokens'] = max_len + entry.data['max_tokens']
|
||||
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
|
||||
|
||||
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
|
||||
retval = []
|
||||
for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens):
|
||||
retval.append((uid, output[:trunc_len]))
|
||||
return retval
|
|
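The batch manager above groups queued requests that share the same sampling parameters and left-pads them to a common length. A small illustration of the padding behaviour (a sketch, not part of the file; importing `batch` requires `energonai` to be installed):

```python
from batch import BatchManagerForGeneration

mgr = BatchManagerForGeneration(max_batch_size=2, pad_token_id=0)
inputs, max_len = mgr._left_padding([
    {'input_ids': [5, 6, 7], 'attention_mask': [1, 1, 1]},
    {'input_ids': [8], 'attention_mask': [1]},
])
print(inputs['input_ids'])       # tensor([[5, 6, 7], [0, 0, 8]])
print(inputs['attention_mask'])  # tensor([[1, 1, 1], [0, 0, 1]])
print(max_len)                   # 3
```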
@ -0,0 +1,15 @@
from locust import HttpUser, task
from json import JSONDecodeError


class GenerationUser(HttpUser):
    @task
    def generate(self):
        prompt = 'Question: What is the longest river on the earth? Answer:'
        for i in range(4, 9):
            data = {'max_tokens': 2**i, 'prompt': prompt}
            with self.client.post('/generation', json=data, catch_response=True) as response:
                if response.status_code in (200, 406):
                    response.success()
                else:
                    response.failure('Response wrong')
@ -0,0 +1,64 @@
|
|||
from collections import OrderedDict
|
||||
from threading import Lock
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Any, Hashable, Dict
|
||||
|
||||
|
||||
class MissCacheError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ListCache:
|
||||
def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None:
|
||||
"""Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied.
|
||||
When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed.
|
||||
|
||||
Args:
|
||||
cache_size (int): Max size for LRU cache.
|
||||
list_size (int): Value list size.
|
||||
fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to [].
|
||||
"""
|
||||
self.cache_size = cache_size
|
||||
self.list_size = list_size
|
||||
self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
|
||||
self.fixed_cache: Dict[Hashable, List[Any]] = {}
|
||||
for key in fixed_keys:
|
||||
self.fixed_cache[key] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def get(self, key: Hashable) -> List[Any]:
|
||||
with self.lock():
|
||||
if key in self.fixed_cache:
|
||||
l = self.fixed_cache[key]
|
||||
if len(l) >= self.list_size:
|
||||
return l
|
||||
elif key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
l = self.cache[key]
|
||||
if len(l) >= self.list_size:
|
||||
return l
|
||||
raise MissCacheError()
|
||||
|
||||
def add(self, key: Hashable, value: Any) -> None:
|
||||
with self.lock():
|
||||
if key in self.fixed_cache:
|
||||
l = self.fixed_cache[key]
|
||||
if len(l) < self.list_size and value not in l:
|
||||
l.append(value)
|
||||
elif key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
l = self.cache[key]
|
||||
if len(l) < self.list_size and value not in l:
|
||||
l.append(value)
|
||||
else:
|
||||
if len(self.cache) >= self.cache_size:
|
||||
self.cache.popitem(last=False)
|
||||
self.cache[key] = [value]
|
||||
|
||||
@contextmanager
|
||||
def lock(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
self._lock.release()
|
|
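As the docstring above describes, `get` only returns the value list for a key once it holds `list_size` entries; before that it raises `MissCacheError`, and `add` appends deduplicated values (evicting the least recently used key when the cache is full). A short usage sketch, not part of the file:

```python
from cache import ListCache, MissCacheError

cache = ListCache(cache_size=2, list_size=2)
cache.add('hello', 'world')
try:
    cache.get('hello')            # still a miss: only 1 of 2 values cached
except MissCacheError:
    cache.add('hello', 'world!')  # fill the value list
print(cache.get('hello'))         # ['world', 'world!']
```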
@ -0,0 +1,123 @@
|
|||
import argparse
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
async def generate(data: GenerationTaskReq, request: Request):
|
||||
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
|
||||
key = (data.prompt, data.max_tokens)
|
||||
try:
|
||||
if cache is None:
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = data.max_tokens
|
||||
inputs['top_k'] = data.top_k
|
||||
inputs['top_p'] = data.top_p
|
||||
inputs['temperature'] = data.temperature
|
||||
try:
|
||||
uid = id(data)
|
||||
engine.submit(uid, inputs)
|
||||
output = await engine.wait(uid)
|
||||
output = tokenizer.decode(output, skip_special_tokens=True)
|
||||
if cache is not None:
|
||||
cache.add(key, output)
|
||||
except QueueFullError as e:
|
||||
raise HTTPException(status_code=406, detail=e.args[0])
|
||||
|
||||
return {'text': output}
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown(*_):
|
||||
engine.shutdown()
|
||||
server.should_exit = True
|
||||
server.force_exit = True
|
||||
await server.shutdown()
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
|
||||
server = uvicorn.Server(config=config)
|
||||
server.run()
|
|
@ -0,0 +1,122 @@
|
|||
import logging
|
||||
import argparse
|
||||
import random
|
||||
from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
|
||||
from transformers import GPT2Tokenizer
|
||||
from energonai import launch_engine, QueueFullError
|
||||
from sanic import Sanic
|
||||
from sanic.request import Request
|
||||
from sanic.response import json
|
||||
from sanic_ext import validate, openapi
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
app = Sanic('opt')
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
@openapi.body(GenerationTaskReq)
|
||||
@validate(json=GenerationTaskReq)
|
||||
async def generate(request: Request, body: GenerationTaskReq):
|
||||
logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
|
||||
key = (body.prompt, body.max_tokens)
|
||||
try:
|
||||
if cache is None:
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = body.max_tokens
|
||||
inputs['top_k'] = body.top_k
|
||||
inputs['top_p'] = body.top_p
|
||||
inputs['temperature'] = body.temperature
|
||||
try:
|
||||
uid = id(body)
|
||||
engine.submit(uid, inputs)
|
||||
output = await engine.wait(uid)
|
||||
assert isinstance(output, Tensor)
|
||||
output = tokenizer.decode(output, skip_special_tokens=True)
|
||||
if cache is not None:
|
||||
cache.add(key, output)
|
||||
except QueueFullError as e:
|
||||
return json({'detail': e.args[0]}, status=406)
|
||||
|
||||
return json({'text': output})
|
||||
|
||||
|
||||
@app.after_server_stop
|
||||
def shutdown(*_):
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
app.run(args.http_host, args.http_port)
|
|
@ -0,0 +1,8 @@
fastapi==0.85.1
locust==2.11.0
pydantic==1.10.2
sanic==22.9.0
sanic_ext==22.9.0
torch>=1.10.0
transformers==4.23.1
uvicorn==0.19.0
@ -0,0 +1,46 @@
# Process OPT-175B weights

You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.

First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`.

Then, `cd metaseq`.

To consolidate checkpoints to eliminate FSDP:

```shell
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
```

You will get 8 files in `<output_dir>`, and you should have the following checksums:
```
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
```

Copy `flat-meta.json` to `<output_dir>`.

Then cd to this directory and unflatten the parameters:

```shell
bash unflat.sh <output_dir>/ <new_output_dir>/
```

Finally, you will get 8 files in `<new_output_dir>` with the following checksums:
```
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
```
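To check the files against the checksums listed above without relying on a system `md5sum`, here is a small sketch in Python (the directory path is a placeholder to replace with your `<new_output_dir>`):

```python
import glob
import hashlib

for path in sorted(glob.glob('<new_output_dir>/reshard-model_part-*.pt')):
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        # the shards are large, so hash them in 1 MiB chunks
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    print(md5.hexdigest(), path)
```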
@ -0,0 +1,55 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def load_json(path: str):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def parse_shape_info(flat_dir: str):
|
||||
data = load_json(os.path.join(flat_dir, 'shape.json'))
|
||||
flat_info = defaultdict(lambda: defaultdict(list))
|
||||
for k, shape in data.items():
|
||||
matched = re.match(r'decoder.layers.\d+', k)
|
||||
if matched is None:
|
||||
flat_key = 'flat_param_0'
|
||||
else:
|
||||
flat_key = f'{matched[0]}.flat_param_0'
|
||||
flat_info[flat_key]['names'].append(k)
|
||||
flat_info[flat_key]['shapes'].append(shape)
|
||||
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
|
||||
return flat_info
|
||||
|
||||
|
||||
def convert(flat_dir: str, output_dir: str, part: int):
|
||||
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
|
||||
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
|
||||
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
|
||||
flat_sd = torch.load(flat_path)
|
||||
print(f'Loaded flat state dict from {flat_path}')
|
||||
output_sd = {}
|
||||
for flat_key, param_meta in flat_meta.items():
|
||||
flat_param = flat_sd['model'][flat_key]
|
||||
assert sum(param_meta['numels']) == flat_param.numel(
|
||||
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
|
||||
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
|
||||
output_sd[name] = param.view(shape)
|
||||
|
||||
torch.save(output_sd, output_path)
|
||||
print(f'Saved unflat state dict to {output_path}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('flat_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('part', type=int)
|
||||
args = parser.parse_args()
|
||||
convert(args.flat_dir, args.output_dir, args.part)
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,7 @@
#!/usr/bin/env sh

for i in $(seq 0 7); do
    python convert_ckpt.py $1 $2 ${i} &
done

wait $(jobs -p)
@ -0,0 +1,55 @@
|
|||
import os
|
||||
import torch
|
||||
from multiprocessing import Pool
|
||||
|
||||
# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main
|
||||
# you can use either wget or git lfs
|
||||
|
||||
path = "/path/to/your/ckpt"
|
||||
new_path = "/path/to/the/processed/ckpt/"
|
||||
|
||||
assert os.path.isdir(path)
|
||||
files = []
|
||||
for filename in os.listdir(path):
|
||||
filepath = os.path.join(path, filename)
|
||||
if os.path.isfile(filepath):
|
||||
files.append(filepath)
|
||||
|
||||
with Pool(14) as pool:
|
||||
ckpts = pool.map(torch.load, files)
|
||||
|
||||
restored = {}
|
||||
for ckpt in ckpts:
|
||||
for k,v in ckpt.items():
|
||||
if(k[0] == 'm'):
|
||||
k = k[6:]
|
||||
if(k == "lm_head.weight"):
|
||||
k = "head.dense.weight"
|
||||
if(k == "decoder.final_layer_norm.weight"):
|
||||
k = "decoder.layer_norm.weight"
|
||||
if(k == "decoder.final_layer_norm.bias"):
|
||||
k = "decoder.layer_norm.bias"
|
||||
restored[k] = v
|
||||
restored["decoder.version"] = "0.0"
|
||||
|
||||
|
||||
split_num = len(restored.keys()) // 60
|
||||
count = 0
|
||||
file_count = 1
|
||||
tmp = {}
|
||||
for k,v in restored.items():
|
||||
print(k)
|
||||
tmp[k] = v
|
||||
count = count + 1
|
||||
if(count == split_num):
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
file_count = file_count + 1
|
||||
count = 0
|
||||
tmp = {}
|
||||
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,53 @@
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Train OPT model with Colossal-AI

## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning causal language modelling at low cost.

We use the pre-trained weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 dataset (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).

## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.

## Quick Start
You can launch training by using the following bash script

```bash
bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
```

- batch-size-per-gpu: number of samples fed to each GPU, default is 16
- mem-cap: limit memory usage within a value in GB, default is 0 (no limit)
- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`, you can request
the pretrained weights from the [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
- gpu-num: the number of GPUs to use, default is 1.

For example, `bash ./run_clm.sh 16 0 125m 8` fine-tunes OPT-125M on 8 GPUs with a per-GPU batch size of 16 and no memory cap.

## Remarkable Performance
On a single GPU, Colossal-AI's automatic strategy provides remarkable performance gains over the ZeRO Offloading strategy by Microsoft DeepSpeed.
Users can experience up to a 40% speedup at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT.png" width=1000/>
</p>

Adopting the distributed training strategy with 8 GPUs is as simple as adding `-nprocs 8` to the training command of Colossal-AI!

More details about what happens behind the scenes can be found in the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d),
and a detailed tutorial will be added to the [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon.
@ -0,0 +1,21 @@
export BS=16
export MEMCAP=0
export MODEL="6.7b"
export GPUNUM=1

for MODEL in "6.7b" "13b" "1.3b"
do
  for GPUNUM in 8 1
  do
    for BS in 16 24 32 8
    do
      for MEMCAP in 0 40
      do
        pkill -9 torchrun
        pkill -9 python

        bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM
      done
    done
  done
done
@ -0,0 +1,6 @@
from colossalai.zero.shard_utils import TensorShardStrategy

zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
                              tensor_placement_policy="auto",
                              reuse_fp16_shard=True),
            optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384))
@ -0,0 +1,32 @@
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class barrier_context():
    """
    This context manager is used to allow one process to execute while blocking all
    other processes in the same process group. This is often useful when downloading is required
    as we only want to download in one process to prevent file corruption.

    Args:
        executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
        parallel_mode (ParallelMode): the parallel mode corresponding to a process group

    Usage:
        with barrier_context():
            dataset = CIFAR10(root='./data', download=True)
    """

    def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
        # the class name is lowercase by convention
        current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
        self.should_block = current_rank != executor_rank
        self.group = gpc.get_group(parallel_mode=parallel_mode)

    def __enter__(self):
        if self.should_block:
            dist.barrier(group=self.group)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if not self.should_block:
            dist.barrier(group=self.group)
@ -0,0 +1,6 @@
colossalai
torch >= 1.8.1
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2
@ -0,0 +1,596 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
|
||||
on a text file or a dataset without using HuggingFace Trainer.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=text-generation
|
||||
"""
|
||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.utils import set_seed
|
||||
from context import barrier_context
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import colossalai
|
||||
import transformers
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import get_current_device, get_dataloader
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
GPT2Tokenizer,
|
||||
OPTForCausalLM,
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
def get_time_stamp():
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument("--train_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A csv or a json file containing the training data.")
|
||||
parser.add_argument("--validation_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A csv or a json file containing the validation data.")
|
||||
parser.add_argument(
|
||||
"--validation_split_percentage",
|
||||
default=5,
|
||||
help="The percentage of the train set used as validation set in case there's no validation split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument("--num_warmup_steps",
                        type=int,
                        default=0,
                        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--block_size",
        type=int,
        default=None,
        help=("Optional input sequence length after tokenization. The training dataset will be truncated into blocks"
              " of this size for training. Defaults to the model max input length for single sentence inputs (taking"
              " into account special tokens)."),
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument("--overwrite_cache",
                        type=bool,
                        default=False,
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--no_keep_linebreaks",
                        action="store_true",
                        help="Do not keep line breaks when using TXT files.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_model_id",
                        type=str,
                        help="The name of the repository to keep in sync with the local `output_dir`.")
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
              ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
              ' Only applicable when `--with_tracking` is passed.'),
    )

    parser.add_argument("--mem_cap", type=int, default=0, help="Cap GPU memory usage to this many GB (0 means no cap).")
    parser.add_argument("--init_in_cpu", action='store_true', default=False, help="Initialize the training model on CPU.")
    args = parser.parse_args()

    # Sanity checks
    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a dataset name or a training/validation file.")
    else:
        if args.train_file is not None:
            extension = args.train_file.split(".")[-1]
            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
        if args.validation_file is not None:
            extension = args.validation_file.split(".")[-1]
            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."

    if args.push_to_hub:
        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."

    return args


def colo_memory_cap(size_in_GB):
    # Optionally cap the fraction of device memory this process may use,
    # which is handy for emulating a smaller GPU when benchmarking.
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print("Using {} GB of GPU memory".format(size_in_GB))


def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch(config=dict())
    logger = get_dist_logger()
    is_main_process = dist.get_rank() == 0

    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
        logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")

    # Handle the repository creation
    with barrier_context():
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    logger.info("Start preparing dataset", ranks=[0])
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                split=f"train[:{args.validation_split_percentage}%]",
            )
            raw_datasets["train"] = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                split=f"train[{args.validation_split_percentage}%:]",
            )
    else:
        data_files = {}
        dataset_args = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        extension = args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{args.validation_split_percentage}%]",
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{args.validation_split_percentage}%:]",
                **dataset_args,
            )
    logger.info("Dataset is prepared", ranks=[0])

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
    logger.info("Model config has been created", ranks=[0])

    if args.model_name_or_path == 'facebook/opt-13b':
        tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    else:
        print(f'load model from {args.model_name_or_path}')
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0])

    if args.init_in_cpu:
        init_dev = torch.device('cpu')
    else:
        init_dev = get_current_device()

    # build model
    if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
        # currently there is a bug in the pretrained opt-13b checkpoint,
        # so we cannot load it until huggingface fixes it
        logger.info("Train a new model from scratch", ranks=[0])
        with ColoInitContext(device=init_dev):
            model = OPTForCausalLM(config)
    else:
        logger.info("Finetune a pre-trained model", ranks=[0])
        with ColoInitContext(device=init_dev):
            model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
                                                   from_tf=bool(".ckpt" in args.model_name_or_path),
                                                   config=config,
                                                   local_files_only=False)

    # enable gradient checkpointing
    model.gradient_checkpointing_enable()

    PLACEMENT_POLICY = 'auto'
    cai_version = colossalai.__version__
    logger.info(f'using Colossal-AI version {cai_version}')
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
        from colossalai.gemini import ChunkManager, GeminiManager
        pg = ProcessGroup()
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
        model = ZeroDDP(model, gemini_manager)

    logger.info(f'{model.__class__.__name__} has been created', ranks=[0])

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name])

    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx.")
            block_size = 1024
    else:
        if args.block_size > tokenizer.model_max_length:
            logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model "
                           f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.")
        block_size = min(args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could add padding instead if the model supported it.
        # You can customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )

    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets["validation"]

    # Log a few random samples from the training set:
    # for index in random.sample(range(len(train_dataset)), 3):
    #     logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    train_dataloader = get_dataloader(train_dataset,
                                      shuffle=True,
                                      add_sampler=True,
                                      collate_fn=default_data_collator,
                                      batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=default_data_collator,
                                 batch_size=args.per_device_eval_batch_size)
    logger.info("Dataloaders have been created", ranks=[0])

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
    logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
    completed_steps = 0
    starting_epoch = 0
    global_step = 0

    for epoch in range(starting_epoch, args.num_train_epochs):

        if completed_steps >= args.max_train_steps:
            break

        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = {k: v.cuda() for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs['loss']
            optimizer.backward(loss)

            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            global_step += 1
            logger.info("Global step {} finished".format(global_step + 1), ranks=[0])

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                batch = {k: v.cuda() for k, v in batch.items()}
                outputs = model(**batch)

            loss = outputs['loss'].unsqueeze(0)
            losses.append(loss)

        losses = torch.cat(losses)
        losses = losses[:len(eval_dataset)]
        try:
            eval_loss = torch.mean(losses)
            perplexity = math.exp(eval_loss)
        except OverflowError:
            perplexity = float("inf")

        logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])

        if args.output_dir is not None:
            model_state = model.state_dict()
            if is_main_process:
                torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
            dist.barrier()
            # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
            # model.load_state_dict(load_state, strict=False)

    logger.info("Training finished", ranks=[0])


if __name__ == "__main__":
    main()
@ -0,0 +1,22 @@
set -x

# positional arguments: batch size, memory cap in GB (0 = no cap), OPT model size, number of GPUs
export BS=${1:-16}
export MEMCAP=${2:-0}
export MODEL=${3:-"125m"}
export GPUNUM=${4:-1}

# make directory for logs
mkdir -p ./logs

export MODEL_PATH="facebook/opt-${MODEL}"

# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
torchrun \
  --nproc_per_node ${GPUNUM} \
  --master_port 19198 \
  run_clm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --output_dir $PWD \
  --mem_cap ${MEMCAP} \
  --model_name_or_path ${MODEL_PATH} \
  --per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
@ -0,0 +1,16 @@
## Overview

This example shows how to use ColossalAI to run huggingface GPT training with Gemini and ZeRO DDP.

## GPT

We use the huggingface transformers GPT2 model. The input data is randomly generated.

## Our Modifications

We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP. The wrapping pattern is sketched below.
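
Conceptually, the adaptation boils down to wrapping the model with `GeminiDDP` and the optimizer with `ZeroOptimizer`, exactly as `train_gpt_demo.py` does. The snippet below is a minimal, illustrative sketch of that pattern, assuming ColossalAI >= 0.1.10 (as pinned in `requirements.txt`), a CUDA device, and a toy stand-in model; it should be launched with `torchrun`, as `run.sh` does.

```python
# Minimal sketch of the Gemini + ZeRO DDP pattern; the tiny MLP is a stand-in model.
import colossalai
import torch.nn as nn
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer

colossalai.launch_from_torch(config={})

# Materialize parameters as ColoTensors so Gemini can manage their placement.
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# Gemini moves parameter chunks between CPU and GPU; ZeroOptimizer shards optimizer states.
model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
optimizer = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=2**5)
```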

## Quick Start

You can launch the training with the following bash script:

```bash
pip install -r requirements.txt
bash run.sh
```
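
By default, `run.sh` launches `train_gpt_demo.py` on 4 GPUs via `torchrun --standalone`, with a tensor parallel degree of 2 (`--tp_degree=2`) and the Gemini placement policy set to CPU (`--placement='cpu'`); adjust these flags and `--nproc_per_node` to match your hardware.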
@ -0,0 +1,3 @@
colossalai >= 0.1.10
torch >= 1.8.1
transformers >= 4.23.1
@ -0,0 +1 @@
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
@ -0,0 +1,241 @@
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from packaging import version

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel


def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini.",
    )
    args = parser.parse_args()
    return args


## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    if param.process_group.tp_world_size() == 1:
        param.set_process_group(pg)
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


## Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_tflops(model_numel, batch_size, seq_len, step_time):
    # roughly 8 FLOPs per parameter per token: ~2 for forward, ~4 for backward,
    # plus ~2 for the recomputed forward pass when activation checkpointing is on
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded
        pg (ProcessGroup): the process group describing the tensor parallel topology
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # set process group for all parameters
            param.set_process_group(pg)

            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)    # row slice
            elif 'wte' in mn or 'wpe' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)    # column slice


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          search_range_mb=32)
    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
        from colossalai.gemini import ChunkManager, GeminiManager
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(placement_policy))
        gemini_manager = GeminiManager(placement_policy, chunk_manager)
        model = ZeroDDP(model, gemini_manager)
    else:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model


def main():
    args = parse_args()

    BATCH_SIZE = 8
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257
    NUM_STEPS = 10

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    pg = ProcessGroup(tp_degree=args.tp_degree)

    logger = get_dist_logger()
    logger.info(get_mem_info(), ranks=[0])

    # build GPT model
    with ColoInitContext(device=get_current_device()):
        model = gpt2_medium(checkpoint=True)

    numel = sum([p.numel() for p in model.parameters()])
    logger.info(f'Model numel: {numel}', ranks=[0])
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    # Tensor Parallelism (TP)
    tensor_parallelize(model, pg)
    # Gemini + ZeRO DDP; note that it must be applied after TP
    model = gemini_zero_dpp(model, pg, args.placement)
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    # build optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

    torch.cuda.synchronize()
    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
        optimizer.backward(loss)
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
        optimizer.step()
        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
        step_time = time() - start
        logger.info(
            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
            ranks=[0])

    torch.cuda.synchronize()


if __name__ == '__main__':
    main()
@ -1,4 +1,5 @@
# Stable Diffusion with Colossal-AI
# Handson 6: Acceleration of Stable Diffusion

*[Colossal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and
fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*