Add handson to ColossalAI. (#1896)

Co-authored-by: Boxiang Wang <boxiang.wang1@gmail.com>
pull/1898/head
BoxiangW 2022-11-11 03:13:22 -05:00 committed by GitHub
parent 0ef8154bfa
commit d9bf83e084
90 changed files with 2157 additions and 2 deletions


@ -0,0 +1,27 @@
# Hands-on 1: Multi-dimensional Parallelism with Colossal-AI
## Install Colossal-AI and other dependencies
```bash
sh install.sh
```
## Prepare Dataset
We use the CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use a custom directory for the dataset, you can set the environment variable `DATA` with the following command.
```bash
export DATA=/path/to/data
```
## Run on a 2x2 device mesh
The current configuration in `config.py` uses TP=2 and PP=2.
```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```
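For reference, with 4 processes and the TP=2, PP=2 setting above, the data-parallel size works out to 1, so the global batch size equals `BATCH_SIZE` in `config.py`. A minimal sketch of that arithmetic (illustrative only, not part of the hands-on scripts):

```python
# Illustrative arithmetic for the 2x2 device mesh above.
world_size = 4              # --nproc_per_node 4
tensor_parallel_size = 2    # TP=2 in config.py
pipeline_parallel_size = 2  # PP=2 in config.py
data_parallel_size = world_size // (tensor_parallel_size * pipeline_parallel_size)   # = 1

batch_size_per_gpu = 256    # BATCH_SIZE in config.py
global_batch_size = batch_size_per_gpu * data_parallel_size                          # = 256
print(data_parallel_size, global_batch_size)
```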


@ -0,0 +1,36 @@
from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is the batch size per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(
pipeline=2,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']
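Since this hands-on is about multi-dimensional parallelism, the `parallel` dict above is the main knob to experiment with. A minimal sketch of an alternative setting, assuming Colossal-AI's standard tensor-parallel mode names (`'1d'`, `'2d'`, `'2.5d'`, `'3d'`) and the same 4 GPUs; 2D tensor parallelism needs a square number of tensor-parallel ranks, so pipeline parallelism is dropped here:

```python
# Hypothetical alternative for config.py (a sketch, not shipped with this hands-on):
# use all 4 GPUs for 2D tensor parallelism and no pipeline stages.
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
```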


@ -0,0 +1,4 @@
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org
pip install titans
colossalai check -i


@ -0,0 +1,116 @@
import os
import colossalai
import torch
from tqdm import tqdm
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from titans.dataloader.cifar10 import build_cifar
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=args.config)
# get logger
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
use_pipeline = is_using_pp()
# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO,
num_classes=10,
init_method='jax',
checkpoint=gpc.config.CHECKPOINT)
if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
# count number of parameters
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_stage = 0
else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
# create dataloaders
root = os.environ.get('DATA', '../data/cifar10')
train_dataloader, test_dataloader = build_cifar(
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = torch.optim.AdamW(model.parameters(
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
# initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
data_iter = iter(train_dataloader)
for epoch in range(gpc.config.NUM_EPOCHS):
# training
engine.train()
if gpc.get_global_rank() == 0:
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description)
else:
progress = range(len(train_dataloader))
for _ in progress:
engine.zero_grad()
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
if __name__ == '__main__':
main()


@ -0,0 +1,20 @@
# Hands-on 2: Sequence Parallelism with BERT
## Prepare Dataset
We use the CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use a custom directory for the dataset, you can set the environment variable `DATA` with the following command.
```bash
export DATA=/path/to/data
```
## Run on a 2x2 device mesh
The current configuration in `config.py` uses sequence parallelism with a parallel size of 4.
```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```


@ -0,0 +1,35 @@
from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is the batch size per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
# parallel setting
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(
tensor=dict(size=4, mode='sequence')
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
# pipeline config (no pipeline stage is configured above, so fall back to a single micro-batch)
NUM_MICRO_BATCHES = parallel.get('pipeline', 1)


@ -0,0 +1,116 @@
import os
import colossalai
import torch
from tqdm import tqdm
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from titans.dataloader.cifar10 import build_cifar
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=args.config)
# get logger
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
use_pipeline = is_using_pp()
# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO,
num_classes=10,
init_method='jax',
checkpoint=gpc.config.CHECKPOINT)
if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
# count number of parameters
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_stage = 0
else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
# create dataloaders
root = os.environ.get('DATA', '../data/cifar10')
train_dataloader, test_dataloader = build_cifar(
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = torch.optim.AdamW(model.parameters(
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
# initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
data_iter = iter(train_dataloader)
for epoch in range(gpc.config.NUM_EPOCHS):
# training
engine.train()
if gpc.get_global_rank() == 0:
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description)
else:
progress = range(len(train_dataloader))
for _ in progress:
engine.zero_grad()
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
if __name__ == '__main__':
main()


@ -1,4 +1,4 @@
# Train ResNet on CIFAR10 with auto_parallel
# Hands-on 3: Auto-Parallelism with ResNet
## Prepare Dataset


@ -0,0 +1,17 @@
# Hands-on 4: Comparison of Large-Batch Training Optimization
## Prepare Dataset
We use the CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use a custom directory for the dataset, you can set the environment variable `DATA` with the following command.
```bash
export DATA=/path/to/data
```
## Run on a 2x2 device mesh
```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```


@ -0,0 +1,36 @@
from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is the batch size per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(
pipeline=2,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']


@ -0,0 +1,117 @@
import os
import colossalai
import torch
from tqdm import tqdm
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import Lars, Lamb
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from titans.dataloader.cifar10 import build_cifar
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=args.config)
# get logger
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
use_pipeline = is_using_pp()
# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO,
num_classes=10,
init_method='jax',
checkpoint=gpc.config.CHECKPOINT)
if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
# count number of parameters
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_stage = 0
else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
# create dataloaders
root = os.environ.get('DATA', '../data/cifar10')
train_dataloader, test_dataloader = build_cifar(
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE,
weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
# initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
data_iter = iter(train_dataloader)
for epoch in range(gpc.config.NUM_EPOCHS):
# training
engine.train()
if gpc.get_global_rank() == 0:
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description)
else:
progress = range(len(train_dataloader))
for _ in progress:
engine.zero_grad()
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
if __name__ == '__main__':
main()
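Since `train.py` already imports both `Lars` and `Lamb` from `colossalai.nn.optimizer`, comparing the two large-batch optimizers only requires swapping the optimizer construction inside the script; a minimal sketch, keeping the same hyperparameters from `config.py`:

```python
# Drop-in replacement for the Lars optimizer constructed above (a sketch for the comparison;
# assumes Lamb accepts the same lr/weight_decay keyword arguments used for Lars here).
optimizer = Lamb(model.parameters(),
                 lr=gpc.config.LEARNING_RATE,
                 weight_decay=gpc.config.WEIGHT_DECAY)
```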


@ -0,0 +1 @@
# Hands-on 5: Fine-tuning and Serving for OPT from Hugging Face


@ -0,0 +1,77 @@
# Overview
This is an example showing how to run OPT text generation. The OPT model is implemented with Colossal-AI and supports tensor parallelism, batching, and caching.
# How to run
Run OPT-125M:
```shell
python opt_fastapi.py opt-125m
```
This will launch an HTTP server on `0.0.0.0:7070` by default; the host and port can be customized. Open `localhost:7070/docs` in your browser to see the OpenAPI docs.
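Once the server is up, you can call the `/generation` endpoint directly. A minimal sketch using the `requests` package (the field names follow the request schema defined in `opt_fastapi.py`):

```python
import requests

# Send a generation request to the locally running server (adjust host/port if customized).
payload = {
    'prompt': 'Question: What is the longest river on the earth? Answer:',
    'max_tokens': 64,     # must be in (0, 256]
    'top_k': 50,          # optional sampling parameters
    'top_p': 0.5,
    'temperature': 0.7,
}
resp = requests.post('http://localhost:7070/generation', json=payload)
if resp.status_code == 200:
    print(resp.json()['text'])
elif resp.status_code == 406:
    print('Request queue is full, try again later.')
```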
## Configure
### Configure model
```shell
python opt_fastapi.py <model>
```
Available models: opt-125m, opt-6.7b, opt-30b, opt-175b.
### Configure tensor parallelism
```shell
python opt_fastapi.py <model> --tp <TensorParallelismWorldSize>
```
The `<TensorParallelismWorldSize>` can be an integer in `[1, #GPUs]`. Default `1`.
### Configure checkpoint
```shell
python opt_fastapi.py <model> --checkpoint <CheckpointPath>
```
The `<CheckpointPath>` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded.
### Configure queue
```shell
python opt_fastapi.py <model> --queue_size <QueueSize>
```
The `<QueueSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer and the request queue is full, incoming requests will be dropped (the response's HTTP status code will be 406).
### Configure batching
```shell
python opt_fastapi.py <model> --max_batch_size <MaxBatchSize>
```
The `<MaxBatchSize>` can be an integer in `[1, MAXINT]`. The engine will build batches whose size is less than or equal to this value.
Note that the actual batch size is not always equal to `<MaxBatchSize>`, as some consecutive requests may not be batchable together.
### Configure caching
```shell
python opt_fastapi.py <model> --cache_size <CacheSize> --cache_list_size <CacheListSize>
```
This will cache `<CacheSize>` unique requests, and for each unique request it caches `<CacheListSize>` different results. A random cached result is returned on a cache hit.
The `<CacheSize>` can be an integer in `[0, MAXINT]`. If it's `0`, caching is disabled. The `<CacheListSize>` can be an integer in `[1, MAXINT]`.
### Other configurations
```shell
python opt_fastapi.py -h
```
# How to benchmark
```shell
cd benchmark
locust
```
Then open the web interface link printed to your console.
# Pre-process pre-trained weights
## OPT-66B
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).
## OPT-175B
See [script/process-opt-175b](./script/process-opt-175b/).


@ -0,0 +1,59 @@
import torch
from typing import List, Deque, Tuple, Hashable, Any
from energonai import BatchManager, SubmitEntry, TaskEntry
class BatchManagerForGeneration(BatchManager):
def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.pad_token_id = pad_token_id
def _left_padding(self, batch_inputs):
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
outputs = {'input_ids': [], 'attention_mask': []}
for inputs in batch_inputs:
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
padding_len = max_len - len(input_ids)
input_ids = [self.pad_token_id] * padding_len + input_ids
attention_mask = [0] * padding_len + attention_mask
outputs['input_ids'].append(input_ids)
outputs['attention_mask'].append(attention_mask)
for k in outputs:
outputs[k] = torch.tensor(outputs[k])
return outputs, max_len
@staticmethod
def _make_batch_key(entry: SubmitEntry) -> tuple:
data = entry.data
return (data['top_k'], data['top_p'], data['temperature'])
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
entry = q.popleft()
uids = [entry.uid]
batch = [entry.data]
while len(batch) < self.max_batch_size:
if len(q) == 0:
break
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
break
if q[0].data['max_tokens'] > entry.data['max_tokens']:
break
e = q.popleft()
batch.append(e.data)
uids.append(e.uid)
inputs, max_len = self._left_padding(batch)
trunc_lens = []
for data in batch:
trunc_lens.append(max_len + data['max_tokens'])
inputs['top_k'] = entry.data['top_k']
inputs['top_p'] = entry.data['top_p']
inputs['temperature'] = entry.data['temperature']
inputs['max_tokens'] = max_len + entry.data['max_tokens']
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
retval = []
for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens):
retval.append((uid, output[:trunc_len]))
return retval
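To see what `_left_padding` produces, note that shorter prompts are padded on the left so that generation stays right-aligned across the batch. A minimal sketch, assuming `energonai` is installed and using arbitrary token ids:

```python
from batch import BatchManagerForGeneration

# Two requests of different lengths; the shorter one gets pad tokens (id 0) and
# attention-mask zeros prepended until both reach the same length.
bm = BatchManagerForGeneration(max_batch_size=2, pad_token_id=0)
inputs, max_len = bm._left_padding([
    {'input_ids': [11, 12, 13], 'attention_mask': [1, 1, 1]},
    {'input_ids': [21, 22, 23, 24, 25], 'attention_mask': [1, 1, 1, 1, 1]},
])
print(max_len)                   # 5
print(inputs['input_ids'])       # tensor([[ 0,  0, 11, 12, 13], [21, 22, 23, 24, 25]])
print(inputs['attention_mask'])  # tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
```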


@ -0,0 +1,15 @@
from locust import HttpUser, task
from json import JSONDecodeError
class GenerationUser(HttpUser):
@task
def generate(self):
prompt = 'Question: What is the longest river on the earth? Answer:'
for i in range(4, 9):
data = {'max_tokens': 2**i, 'prompt': prompt}
with self.client.post('/generation', json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
response.failure('Response wrong')


@ -0,0 +1,64 @@
from collections import OrderedDict
from threading import Lock
from contextlib import contextmanager
from typing import List, Any, Hashable, Dict
class MissCacheError(Exception):
pass
class ListCache:
def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None:
"""Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied.
When the value list is not yet full, a lookup is treated as a cache miss; otherwise it is a hit. Duplicate values are not stored.
Args:
cache_size (int): Max size for LRU cache.
list_size (int): Value list size.
fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to [].
"""
self.cache_size = cache_size
self.list_size = list_size
self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
self.fixed_cache: Dict[Hashable, List[Any]] = {}
for key in fixed_keys:
self.fixed_cache[key] = []
self._lock = Lock()
def get(self, key: Hashable) -> List[Any]:
with self.lock():
if key in self.fixed_cache:
l = self.fixed_cache[key]
if len(l) >= self.list_size:
return l
elif key in self.cache:
self.cache.move_to_end(key)
l = self.cache[key]
if len(l) >= self.list_size:
return l
raise MissCacheError()
def add(self, key: Hashable, value: Any) -> None:
with self.lock():
if key in self.fixed_cache:
l = self.fixed_cache[key]
if len(l) < self.list_size and value not in l:
l.append(value)
elif key in self.cache:
self.cache.move_to_end(key)
l = self.cache[key]
if len(l) < self.list_size and value not in l:
l.append(value)
else:
if len(self.cache) >= self.cache_size:
self.cache.popitem(last=False)
self.cache[key] = [value]
@contextmanager
def lock(self):
try:
self._lock.acquire()
yield
finally:
self._lock.release()
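The hit/miss behavior of `ListCache` can be seen with a tiny standalone example (the key and values here are arbitrary):

```python
from cache import ListCache, MissCacheError

# A cache that holds up to 4 keys and treats a key as a hit only
# once it has accumulated 2 distinct values.
cache = ListCache(cache_size=4, list_size=2)
key = ('some prompt', 64)

cache.add(key, 'first result')
try:
    cache.get(key)               # only 1 of 2 values stored -> miss
except MissCacheError:
    print('miss')

cache.add(key, 'second result')
print(cache.get(key))            # ['first result', 'second result'] -> hit
```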


@ -0,0 +1,123 @@
import argparse
import logging
import random
from typing import Optional
import uvicorn
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = FastAPI()
@app.post('/generation')
async def generate(data: GenerationTaskReq, request: Request):
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
key = (data.prompt, data.max_tokens)
try:
if cache is None:
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
except MissCacheError:
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = data.max_tokens
inputs['top_k'] = data.top_k
inputs['top_p'] = data.top_p
inputs['temperature'] = data.temperature
try:
uid = id(data)
engine.submit(uid, inputs)
output = await engine.wait(uid)
output = tokenizer.decode(output, skip_special_tokens=True)
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
raise HTTPException(status_code=406, detail=e.args[0])
return {'text': output}
@app.on_event("shutdown")
async def shutdown(*_):
engine.shutdown()
server.should_exit = True
server.force_exit = True
await server.shutdown()
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
for k, v in args.__dict__.items():
print(f'{k} = {v}')
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
server = uvicorn.Server(config=config)
server.run()


@ -0,0 +1,122 @@
import logging
import argparse
import random
from torch import Tensor
from pydantic import BaseModel, Field
from typing import Optional
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
from transformers import GPT2Tokenizer
from energonai import launch_engine, QueueFullError
from sanic import Sanic
from sanic.request import Request
from sanic.response import json
from sanic_ext import validate, openapi
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = Sanic('opt')
@app.post('/generation')
@openapi.body(GenerationTaskReq)
@validate(json=GenerationTaskReq)
async def generate(request: Request, body: GenerationTaskReq):
logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
key = (body.prompt, body.max_tokens)
try:
if cache is None:
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
except MissCacheError:
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = body.max_tokens
inputs['top_k'] = body.top_k
inputs['top_p'] = body.top_p
inputs['temperature'] = body.temperature
try:
uid = id(body)
engine.submit(uid, inputs)
output = await engine.wait(uid)
assert isinstance(output, Tensor)
output = tokenizer.decode(output, skip_special_tokens=True)
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
return json({'detail': e.args[0]}, status=406)
return json({'text': output})
@app.after_server_stop
def shutdown(*_):
engine.shutdown()
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
for k, v in args.__dict__.items():
print(f'{k} = {v}')
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
app.run(args.http_host, args.http_port)


@ -0,0 +1,8 @@
fastapi==0.85.1
locust==2.11.0
pydantic==1.10.2
sanic==22.9.0
sanic_ext==22.9.0
torch>=1.10.0
transformers==4.23.1
uvicorn==0.19.0


@ -0,0 +1,46 @@
# Process OPT-175B weights
You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.
First, install `metaseq` and clone its repository: `git clone https://github.com/facebookresearch/metaseq.git`.
Then `cd metaseq`.
To consolidate checkpoints to eliminate FSDP:
```shell
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
```
You will get 8 files in `<output_dir>`, and you should have the following checksums:
```
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
```
Copy `flat-meta.json` to `<output_dir>`.
Then change to this directory (the one containing `unflat.sh`) and unflatten the parameters.
```shell
bash unflat.sh <output_dir>/ <new_output_dir>/
```
Finally, you will get 8 files in `<new_output_dir>` with the following checksums:
```
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
```
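To spot-check one of the converted files, you can load it and inspect a few parameter names; a minimal sketch (the exact key names come from the `names` entries in `flat-meta.json`, so they may differ):

```python
import torch

# Load one unflattened model-parallel part and print a few entries (sketch).
sd = torch.load('reshard-model_part-0.pt', map_location='cpu')
for name, tensor in list(sd.items())[:5]:
    print(name, tuple(tensor.shape))
```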


@ -0,0 +1,55 @@
import argparse
import json
import os
import re
from collections import defaultdict
import numpy as np
import torch
def load_json(path: str):
with open(path) as f:
return json.load(f)
def parse_shape_info(flat_dir: str):
data = load_json(os.path.join(flat_dir, 'shape.json'))
flat_info = defaultdict(lambda: defaultdict(list))
for k, shape in data.items():
matched = re.match(r'decoder.layers.\d+', k)
if matched is None:
flat_key = 'flat_param_0'
else:
flat_key = f'{matched[0]}.flat_param_0'
flat_info[flat_key]['names'].append(k)
flat_info[flat_key]['shapes'].append(shape)
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
return flat_info
def convert(flat_dir: str, output_dir: str, part: int):
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
flat_sd = torch.load(flat_path)
print(f'Loaded flat state dict from {flat_path}')
output_sd = {}
for flat_key, param_meta in flat_meta.items():
flat_param = flat_sd['model'][flat_key]
assert sum(param_meta['numels']) == flat_param.numel(
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
output_sd[name] = param.view(shape)
torch.save(output_sd, output_path)
print(f'Saved unflat state dict to {output_path}')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('flat_dir')
parser.add_argument('output_dir')
parser.add_argument('part', type=int)
args = parser.parse_args()
convert(args.flat_dir, args.output_dir, args.part)

File diff suppressed because one or more lines are too long


@ -0,0 +1,7 @@
#!/usr/bin/env sh
for i in $(seq 0 7); do
python convert_ckpt.py $1 $2 ${i} &
done
wait $(jobs -p)


@ -0,0 +1,55 @@
import os
import torch
from multiprocessing import Pool
# Download the PyTorch model checkpoints from https://huggingface.co/facebook/opt-66b/tree/main
# You can use either wget or git lfs to fetch them.
path = "/path/to/your/ckpt"
new_path = "/path/to/the/processed/ckpt/"
assert os.path.isdir(path)
files = []
for filename in os.listdir(path):
filepath = os.path.join(path, filename)
if os.path.isfile(filepath):
files.append(filepath)
with Pool(14) as pool:
ckpts = pool.map(torch.load, files)
restored = {}
for ckpt in ckpts:
for k,v in ckpt.items():
if(k[0] == 'm'):
k = k[6:]
if(k == "lm_head.weight"):
k = "head.dense.weight"
if(k == "decoder.final_layer_norm.weight"):
k = "decoder.layer_norm.weight"
if(k == "decoder.final_layer_norm.bias"):
k = "decoder.layer_norm.bias"
restored[k] = v
restored["decoder.version"] = "0.0"
split_num = len(restored.keys()) // 60
count = 0
file_count = 1
tmp = {}
for k,v in restored.items():
print(k)
tmp[k] = v
count = count + 1
if(count == split_num):
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))
file_count = file_count + 1
count = 0
tmp = {}
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))


@ -0,0 +1,53 @@
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Train OPT model with Colossal-AI
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion-parameter AI language model, which enables AI programmers to build various downstream tasks and application deployments.
The following example from [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning for Causal Language Modeling at low cost.
We use the pre-trained weights of the OPT model provided by the Hugging Face Hub and fine-tune on the raw WikiText-2 dataset (no tokens were replaced before
tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
## Quick Start
You can launch training by using the following bash script
```bash
bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
```
- batch-size-per-gpu: number of samples fed to each GPU; the default is 16
- mem-cap: limit memory usage to the given value in GB; the default is 0 (no limit)
- model: the size of the OPT model; the default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`, you can request
the pretrained weights from the [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
- gpu-num: the number of GPUs to use; the default is 1.
## Remarkable Performance
On a single GPU, Colossal-AI's automatic strategy provides remarkable performance gains over the ZeRO Offloading strategy from Microsoft DeepSpeed.
Users can experience up to a 40% speedup at a variety of model scales, while a traditional deep learning framework such as plain PyTorch can no longer support training models of such scale on a single GPU.
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT.png" width=1000/>
</p>
Adopting the distributed training strategy with 8 GPUs is as simple as adding `-nprocs 8` to the Colossal-AI training command!
More details about what happens behind the scenes can be found in the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d),
and a detailed tutorial will be added to the [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon.


@ -0,0 +1,21 @@
export BS=16
export MEMCAP=0
export MODEL="6.7b"
export GPUNUM=1
for MODEL in "6.7b" "13b" "1.3b"
do
for GPUNUM in 8 1
do
for BS in 16 24 32 8
do
for MEMCAP in 0 40
do
pkill -9 torchrun
pkill -9 python
bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM
done
done
done
done


@ -0,0 +1,6 @@
from colossalai.zero.shard_utils import TensorShardStrategy
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",
reuse_fp16_shard=True),
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384))


@ -0,0 +1,32 @@
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class barrier_context():
"""
This context manager is used to allow one process to execute while blocking all
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
parallel_mode (ParallelMode): the parallel mode corresponding to a process group
Usage:
with barrier_context():
dataset = CIFAR10(root='./data', download=True)
"""
def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
# the class name is lowercase by convention
current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
self.should_block = current_rank != executor_rank
self.group = gpc.get_group(parallel_mode=parallel_mode)
def __enter__(self):
if self.should_block:
dist.barrier(group=self.group)
def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.should_block:
dist.barrier(group=self.group)


@ -0,0 +1,6 @@
colossalai
torch >= 1.8.1
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2


@ -0,0 +1,596 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
on a text file or a dataset without using HuggingFace Trainer.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import math
import os
import time
from itertools import chain
import datasets
import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AutoConfig,
AutoTokenizer,
GPT2Tokenizer,
OPTForCausalLM,
SchedulerType,
default_data_collator,
get_scheduler,
)
from transformers.utils.versions import require_version
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
def get_time_stamp():
torch.cuda.synchronize()
return time.time()
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help="The name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument("--train_file",
type=str,
default=None,
help="A csv or a json file containing the training data.")
parser.add_argument("--validation_file",
type=str,
default=None,
help="A csv or a json file containing the validation data.")
parser.add_argument(
"--validation_split_percentage",
default=5,
help="The percentage of the train set used as validation set in case there's no validation split",
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
)
parser.add_argument(
"--config_name",
type=str,
default=None,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument("--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--model_type",
type=str,
default=None,
help="Model type to use if training from scratch.",
choices=MODEL_TYPES,
)
parser.add_argument(
"--block_size",
type=int,
default=None,
help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of"
" this size for training. Default to the model max input length for single sentence inputs (take into"
" account special tokens)."),
)
parser.add_argument(
"--preprocessing_num_workers",
type=int,
default=None,
help="The number of processes to use for the preprocessing.",
)
parser.add_argument("--overwrite_cache",
type=bool,
default=False,
help="Overwrite the cached training and evaluation sets")
parser.add_argument("--no_keep_linebreaks",
action="store_true",
help="Do not keep line breaks when using TXT files.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_model_id",
type=str,
help="The name of the repository to keep in sync with the local `output_dir`.")
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."),
)
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
return args
def colo_memory_cap(size_in_GB):
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
cuda_capacity = colo_device_memory_capacity(get_current_device())
if size_in_GB * (1024**3) < cuda_capacity:
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
print("Using {} GB of GPU memory".format(size_in_GB))
def main():
args = parse_args()
disable_existing_loggers()
colossalai.launch_from_torch(config=dict())
logger = get_dist_logger()
is_main_process = dist.get_rank() == 0
if is_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")
# Handle the repository creation
with barrier_context():
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
logger.info("Start preparing dataset", ranks=[0])
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
)
raw_datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{args.validation_split_percentage}%]",
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{args.validation_split_percentage}%:]",
**dataset_args,
)
logger.info("Dataset is prepared", ranks=[0])
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
if args.config_name:
config = AutoConfig.from_pretrained(args.config_name)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(args.model_name_or_path)
else:
config = CONFIG_MAPPING[args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
logger.info("Model config has been created", ranks=[0])
if args.model_name_or_path == 'facebook/opt-13b':
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
else:
print(f'load model from {args.model_name_or_path}')
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0])
if args.init_in_cpu:
init_dev = torch.device('cpu')
else:
init_dev = get_current_device()
# build model
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
# currently, there is a bug in the pretrained opt-13b checkpoint
# we cannot import it until Hugging Face fixes it
logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev):
model = OPTForCausalLM(config)
else:
logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
local_files_only=False)
# enable gradient checkpointing
model.gradient_checkpointing_enable()
PLACEMENT_POLICY = 'auto'
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
def tokenize_function(examples):
return tokenizer(examples[text_column_name])
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
tokenized_datasets = raw_datasets.map(
tokenize_function,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on dataset",
)
if args.block_size is None:
block_size = tokenizer.model_max_length
if block_size > 1024:
logger.warning(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
"Picking 1024 instead. You can change that default value by passing --block_size xxx.")
block_size = 1024
else:
if args.block_size > tokenizer.model_max_length:
logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.")
block_size = min(args.block_size, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i:i + block_size] for i in range(0, total_length, block_size)
] for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc=f"Grouping texts in chunks of {block_size}",
)
train_dataset = lm_datasets["train"]
eval_dataset = lm_datasets["validation"]
# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# DataLoaders creation:
train_dataloader = get_dataloader(train_dataset,
shuffle=True,
add_sampler=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size)
eval_dataloader = DataLoader(eval_dataset,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size)
logger.info("Dataloaders have been created", ranks=[0])
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
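# wrap with ZeroOptimizer so optimizer states are sharded and loss scaling is handled
# (initial_scale sets the starting loss-scale factor)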
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
logger.info("***** Running training *****", ranks=[0])
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
completed_steps = 0
starting_epoch = 0
global_step = 0
for epoch in range(starting_epoch, args.num_train_epochs):
if completed_steps >= args.max_train_steps:
break
model.train()
for step, batch in enumerate(train_dataloader):
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs['loss']
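# backward through the ZeRO optimizer (instead of loss.backward()) so gradients
# are managed by Gemini's chunk manager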
optimizer.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
completed_steps += 1
global_step += 1
logger.info("Global step {} finished".format(global_step + 1), ranks=[0])
if completed_steps >= args.max_train_steps:
break
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs['loss'].unsqueeze(0)
losses.append(loss)
losses = torch.cat(losses)
losses = losses[:len(eval_dataset)]
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
if args.output_dir is not None:
model_state = model.state_dict()
if is_main_process:
torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
dist.barrier()
# load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
# model.load_state_dict(load_state, strict=False)
logger.info("Training finished", ranks=[0])
if __name__ == "__main__":
main()

View File

@ -0,0 +1,22 @@
set -x
export BS=${1:-16}
export MEMCAP=${2:-0}
export MODEL=${3:-"125m"}
export GPUNUM=${4:-1}
# make directory for logs
mkdir -p ./logs
export MODEL_PATH="facebook/opt-${MODEL}"
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
torchrun \
--nproc_per_node ${GPUNUM} \
--master_port 19198 \
run_clm.py \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--output_dir $PWD \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODEL_PATH} \
--per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log

View File

@ -0,0 +1,16 @@
## Overview
This example shows how to use ColossalAI to run huggingface GPT training with Gemini and ZeRO DDP.
## GPT
We use the huggingface transformers GPT2 model. The input data is randomly generated.
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
## Quick Start
You can launch training with the following bash script:
```bash
pip install -r requirements.txt
bash run.sh
```
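For reference, the core of `train_gpt_demo.py` builds the model inside `ColoInitContext` and then hands it to Gemini and a ZeRO optimizer. The snippet below is a minimal sketch of that pattern, assuming a Colossal-AI version above 0.1.10 (so that `GeminiDDP` is available) and a launch via `torchrun`; it is not a drop-in replacement for the full script.

```python
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

# build the model under ColoInitContext so its parameters are created as ColoParameters
with ColoInitContext(device=get_current_device()):
    model = GPT2LMHeadModel(GPT2Config(n_embd=1024, n_layer=24, n_head=16))

# Gemini manages parameter placement across GPU/CPU and applies ZeRO-style DDP
model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu', pin_memory=True)

# ZeroOptimizer shards optimizer states and handles loss scaling; use optimizer.backward(loss)
optimizer = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=2**5)
```

With `--tp_degree` greater than 1, the full script additionally shards parameters with its `tensor_parallelize` helper before the Gemini wrapping.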

View File

@ -0,0 +1,3 @@
colossalai >= 0.1.10
torch >= 1.8.1
transformers >= 4.23.1

View File

@ -0,0 +1 @@
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log

View File

@ -0,0 +1,241 @@
from functools import partial
from time import time
import psutil
import torch
import torch.nn as nn
from packaging import version
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini.",
)
args = parser.parse_args()
return args
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
if param.process_group.tp_world_size() == 1:
param.set_process_group(pg)
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
## Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def gpt2_medium(checkpoint=False):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_xl(checkpoint=True):
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
def gpt2_10b(checkpoint=True):
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
# set process group for all parameters
param.set_process_group(pg)
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # column slice
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplementedError(f"CAI version {cai_version} is not supported")
return model
def main():
args = parse_args()
BATCH_SIZE = 8
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
disable_existing_loggers()
colossalai.launch_from_torch(config={})
pg = ProcessGroup(tp_degree=args.tp_degree)
logger = get_dist_logger()
logger.info(get_mem_info(), ranks=[0])
# build GPT model
with ColoInitContext(device=get_current_device()):
model = gpt2_medium(checkpoint=True)
numel = sum([p.numel() for p in model.parameters()])
logger.info(f'Model numel: {numel}', ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DDP; note that it must be applied after TP
model = gemini_zero_dpp(model, pg, args.placement)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
# build criterion
criterion = GPTLMLoss()
# build optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
torch.cuda.synchronize()
model.train()
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
optimizer.backward(loss)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
optimizer.step()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])
torch.cuda.synchronize()
if __name__ == '__main__':
main()

View File

@ -1,4 +1,5 @@
# Stable Diffusion with Colossal-AI
# Handson 6: Acceleration of Stable Diffusion
*[Colossal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower-cost solution for pretraining and
fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*