[tutorial] edited hands-on practices (#1899)

* Add handson to ColossalAI.

* Change names of handsons and edit sequence parallel example.

* Edit wrong folder name

* resolve conflict

* delete readme
pull/1903/head
BoxiangW 2022-11-11 04:08:17 -05:00 committed by GitHub
parent d9bf83e084
commit ca6e75bc28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
121 changed files with 20464 additions and 0 deletions

View File

@ -0,0 +1,27 @@
# Handson 1: Multi-dimensional Parallelism with Colossal-AI
## Install Colossal-AI and other dependencies
```bash
sh install.sh
```
## Prepare Dataset
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
```bash
export DATA=/path/to/data
```
## Run on 2*2 device mesh
Current configuration setting on `config.py` is TP=2, PP=2.
```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```

View File

@ -0,0 +1,36 @@
from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 256
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(
pipeline=2,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']

View File

@ -0,0 +1,4 @@
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org
pip install titans
colossalai check -i

View File

@ -0,0 +1,116 @@
import os
import colossalai
import torch
from tqdm import tqdm
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from titans.dataloader.cifar10 import build_cifar
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=args.config)
# get logger
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
use_pipeline = is_using_pp()
# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO,
num_classes=10,
init_method='jax',
checkpoint=gpc.config.CHECKPOINT)
if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
# count number of parameters
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_stage = 0
else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
# create dataloaders
root = os.environ.get('DATA', '../data/cifar10')
train_dataloader, test_dataloader = build_cifar(
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = torch.optim.AdamW(model.parameters(
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
# initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
data_iter = iter(train_dataloader)
for epoch in range(gpc.config.NUM_EPOCHS):
# training
engine.train()
if gpc.get_global_rank() == 0:
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description)
else:
progress = range(len(train_dataloader))
for _ in progress:
engine.zero_grad()
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,17 @@
# Handson 4: Comparison of Large Batch Training Optimization
## Prepare Dataset
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
```bash
export DATA=/path/to/data
```
## Run on 2*2 device mesh
```bash
colossalai run --nproc_per_node 4 train.py --config config.py
```

View File

@ -0,0 +1,36 @@
from colossalai.amp import AMP_TYPE
# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 10
WARMUP_EPOCHS = 3
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 512
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
NUM_CLASSES = 1000
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
# parallel setting
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
parallel = dict(
pipeline=2,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
# pipeline config
NUM_MICRO_BATCHES = parallel['pipeline']

View File

@ -0,0 +1,117 @@
import os
import colossalai
import torch
from tqdm import tqdm
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import Lars, Lamb
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext
from titans.model.vit.vit import _create_vit_model
from titans.dataloader.cifar10 import build_cifar
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from torch
colossalai.launch_from_torch(config=args.config)
# get logger
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
use_pipeline = is_using_pp()
# create model
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
patch_size=gpc.config.PATCH_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
depth=gpc.config.DEPTH,
num_heads=gpc.config.NUM_HEADS,
mlp_ratio=gpc.config.MLP_RATIO,
num_classes=10,
init_method='jax',
checkpoint=gpc.config.CHECKPOINT)
if use_pipeline:
pipelinable = PipelinableContext()
with pipelinable:
model = _create_vit_model(**model_kwargs)
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
else:
model = _create_vit_model(**model_kwargs)
# count number of parameters
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
if not gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_stage = 0
else:
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
# create dataloaders
root = os.environ.get('DATA', '../data/cifar10')
train_dataloader, test_dataloader = build_cifar(
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE,
weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
# initialize
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
data_iter = iter(train_dataloader)
for epoch in range(gpc.config.NUM_EPOCHS):
# training
engine.train()
if gpc.get_global_rank() == 0:
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
progress = tqdm(range(len(train_dataloader)), desc=description)
else:
progress = range(len(train_dataloader))
for _ in progress:
engine.zero_grad()
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
if __name__ == '__main__':
main()

View File

@ -0,0 +1 @@
# Handson 5: Fine-tuning and Serving for OPT from Hugging Face

View File

@ -0,0 +1,77 @@
# Overview
This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI.
It supports tensor parallelism, batching and caching.
# How to run
Run OPT-125M:
```shell
python opt_fastapi.py opt-125m
```
It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs.
## Configure
### Configure model
```shell
python opt_fastapi.py <model>
```
Available models: opt-125m, opt-6.7b, opt-30b, opt-175b.
### Configure tensor parallelism
```shell
python opt_fastapi.py <model> --tp <TensorParallelismWorldSize>
```
The `<TensorParallelismWorldSize>` can be an integer in `[1, #GPUs]`. Default `1`.
### Configure checkpoint
```shell
python opt_fastapi.py <model> --checkpoint <CheckpointPath>
```
The `<CheckpointPath>` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded.
### Configure queue
```shell
python opt_fastapi.py <model> --queue_size <QueueSize>
```
The `<QueueSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406).
### Configure bathcing
```shell
python opt_fastapi.py <model> --max_batch_size <MaxBatchSize>
```
The `<MaxBatchSize>` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value.
Note that the batch size is not always equal to `<MaxBatchSize>`, as some consecutive requests may not be batched.
### Configure caching
```shell
python opt_fastapi.py <model> --cache_size <CacheSize> --cache_list_size <CacheListSize>
```
This will cache `<CacheSize>` unique requests. And for each unique request, it cache `<CacheListSize>` different results. A random result will be returned if the cache is hit.
The `<CacheSize>` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `<CacheListSize>` can be an integer in `[1, MAXINT]`.
### Other configurations
```shell
python opt_fastapi.py -h
```
# How to benchmark
```shell
cd benchmark
locust
```
Then open the web interface link which is on your console.
# Pre-process pre-trained weights
## OPT-66B
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).
## OPT-175B
See [script/process-opt-175b](./script/process-opt-175b/).

View File

@ -0,0 +1,59 @@
import torch
from typing import List, Deque, Tuple, Hashable, Any
from energonai import BatchManager, SubmitEntry, TaskEntry
class BatchManagerForGeneration(BatchManager):
def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.pad_token_id = pad_token_id
def _left_padding(self, batch_inputs):
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
outputs = {'input_ids': [], 'attention_mask': []}
for inputs in batch_inputs:
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
padding_len = max_len - len(input_ids)
input_ids = [self.pad_token_id] * padding_len + input_ids
attention_mask = [0] * padding_len + attention_mask
outputs['input_ids'].append(input_ids)
outputs['attention_mask'].append(attention_mask)
for k in outputs:
outputs[k] = torch.tensor(outputs[k])
return outputs, max_len
@staticmethod
def _make_batch_key(entry: SubmitEntry) -> tuple:
data = entry.data
return (data['top_k'], data['top_p'], data['temperature'])
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
entry = q.popleft()
uids = [entry.uid]
batch = [entry.data]
while len(batch) < self.max_batch_size:
if len(q) == 0:
break
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
break
if q[0].data['max_tokens'] > entry.data['max_tokens']:
break
e = q.popleft()
batch.append(e.data)
uids.append(e.uid)
inputs, max_len = self._left_padding(batch)
trunc_lens = []
for data in batch:
trunc_lens.append(max_len + data['max_tokens'])
inputs['top_k'] = entry.data['top_k']
inputs['top_p'] = entry.data['top_p']
inputs['temperature'] = entry.data['temperature']
inputs['max_tokens'] = max_len + entry.data['max_tokens']
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
retval = []
for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens):
retval.append((uid, output[:trunc_len]))
return retval

View File

@ -0,0 +1,15 @@
from locust import HttpUser, task
from json import JSONDecodeError
class GenerationUser(HttpUser):
@task
def generate(self):
prompt = 'Question: What is the longest river on the earth? Answer:'
for i in range(4, 9):
data = {'max_tokens': 2**i, 'prompt': prompt}
with self.client.post('/generation', json=data, catch_response=True) as response:
if response.status_code in (200, 406):
response.success()
else:
response.failure('Response wrong')

View File

@ -0,0 +1,64 @@
from collections import OrderedDict
from threading import Lock
from contextlib import contextmanager
from typing import List, Any, Hashable, Dict
class MissCacheError(Exception):
pass
class ListCache:
def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None:
"""Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied.
When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed.
Args:
cache_size (int): Max size for LRU cache.
list_size (int): Value list size.
fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to [].
"""
self.cache_size = cache_size
self.list_size = list_size
self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
self.fixed_cache: Dict[Hashable, List[Any]] = {}
for key in fixed_keys:
self.fixed_cache[key] = []
self._lock = Lock()
def get(self, key: Hashable) -> List[Any]:
with self.lock():
if key in self.fixed_cache:
l = self.fixed_cache[key]
if len(l) >= self.list_size:
return l
elif key in self.cache:
self.cache.move_to_end(key)
l = self.cache[key]
if len(l) >= self.list_size:
return l
raise MissCacheError()
def add(self, key: Hashable, value: Any) -> None:
with self.lock():
if key in self.fixed_cache:
l = self.fixed_cache[key]
if len(l) < self.list_size and value not in l:
l.append(value)
elif key in self.cache:
self.cache.move_to_end(key)
l = self.cache[key]
if len(l) < self.list_size and value not in l:
l.append(value)
else:
if len(self.cache) >= self.cache_size:
self.cache.popitem(last=False)
self.cache[key] = [value]
@contextmanager
def lock(self):
try:
self._lock.acquire()
yield
finally:
self._lock.release()

View File

@ -0,0 +1,123 @@
import argparse
import logging
import random
from typing import Optional
import uvicorn
from energonai import QueueFullError, launch_engine
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import GPT2Tokenizer
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = FastAPI()
@app.post('/generation')
async def generate(data: GenerationTaskReq, request: Request):
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
key = (data.prompt, data.max_tokens)
try:
if cache is None:
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
except MissCacheError:
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = data.max_tokens
inputs['top_k'] = data.top_k
inputs['top_p'] = data.top_p
inputs['temperature'] = data.temperature
try:
uid = id(data)
engine.submit(uid, inputs)
output = await engine.wait(uid)
output = tokenizer.decode(output, skip_special_tokens=True)
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
raise HTTPException(status_code=406, detail=e.args[0])
return {'text': output}
@app.on_event("shutdown")
async def shutdown(*_):
engine.shutdown()
server.should_exit = True
server.force_exit = True
await server.shutdown()
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
for k, v in args.__dict__.items():
print(f'{k} = {v}')
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
server = uvicorn.Server(config=config)
server.run()

View File

@ -0,0 +1,122 @@
import logging
import argparse
import random
from torch import Tensor
from pydantic import BaseModel, Field
from typing import Optional
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
from transformers import GPT2Tokenizer
from energonai import launch_engine, QueueFullError
from sanic import Sanic
from sanic.request import Request
from sanic.response import json
from sanic_ext import validate, openapi
from batch import BatchManagerForGeneration
from cache import ListCache, MissCacheError
class GenerationTaskReq(BaseModel):
max_tokens: int = Field(gt=0, le=256, example=64)
prompt: str = Field(
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
top_k: Optional[int] = Field(default=None, gt=0, example=50)
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
app = Sanic('opt')
@app.post('/generation')
@openapi.body(GenerationTaskReq)
@validate(json=GenerationTaskReq)
async def generate(request: Request, body: GenerationTaskReq):
logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
key = (body.prompt, body.max_tokens)
try:
if cache is None:
raise MissCacheError()
outputs = cache.get(key)
output = random.choice(outputs)
logger.info('Cache hit')
except MissCacheError:
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
inputs['max_tokens'] = body.max_tokens
inputs['top_k'] = body.top_k
inputs['top_p'] = body.top_p
inputs['temperature'] = body.temperature
try:
uid = id(body)
engine.submit(uid, inputs)
output = await engine.wait(uid)
assert isinstance(output, Tensor)
output = tokenizer.decode(output, skip_special_tokens=True)
if cache is not None:
cache.add(key, output)
except QueueFullError as e:
return json({'detail': e.args[0]}, status=406)
return json({'text': output})
@app.after_server_stop
def shutdown(*_):
engine.shutdown()
def get_model_fn(model_name: str):
model_map = {
'opt-125m': opt_125M,
'opt-6.7b': opt_6B,
'opt-30b': opt_30B,
'opt-175b': opt_175B
}
return model_map[model_name]
def print_args(args: argparse.Namespace):
print('\n==> Args:')
for k, v in args.__dict__.items():
print(f'{k} = {v}')
FIXED_CACHE_KEYS = [
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
parser.add_argument('--tp', type=int, default=1)
parser.add_argument('--master_host', default='localhost')
parser.add_argument('--master_port', type=int, default=19990)
parser.add_argument('--rpc_port', type=int, default=19980)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--pipe_size', type=int, default=1)
parser.add_argument('--queue_size', type=int, default=0)
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--checkpoint', default=None)
parser.add_argument('--cache_size', type=int, default=0)
parser.add_argument('--cache_list_size', type=int, default=1)
args = parser.parse_args()
print_args(args)
model_kwargs = {}
if args.checkpoint is not None:
model_kwargs['checkpoint'] = args.checkpoint
logger = logging.getLogger(__name__)
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
if args.cache_size > 0:
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
else:
cache = None
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
pad_token_id=tokenizer.pad_token_id),
pipe_size=args.pipe_size,
queue_size=args.queue_size,
**model_kwargs)
app.run(args.http_host, args.http_port)

View File

@ -0,0 +1,8 @@
fastapi==0.85.1
locust==2.11.0
pydantic==1.10.2
sanic==22.9.0
sanic_ext==22.9.0
torch>=1.10.0
transformers==4.23.1
uvicorn==0.19.0

View File

@ -0,0 +1,46 @@
# Process OPT-175B weights
You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.
First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`.
Then, `cd metaseq`.
To consolidate checkpoints to eliminate FSDP:
```shell
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
```
You will get 8 files in `<output_dir>`, and you should have the following checksums:
```
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
```
Copy `flat-meta.json` to `<output_dir>`.
Then cd to this dir, and we unflatten parameters.
```shell
bash unflat.sh <output_dir>/ <new_output_dir>/
```
Finally, you will get 8 files in `<new_output_dir>` with following checksums:
```
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
```

View File

@ -0,0 +1,55 @@
import argparse
import json
import os
import re
from collections import defaultdict
import numpy as np
import torch
def load_json(path: str):
with open(path) as f:
return json.load(f)
def parse_shape_info(flat_dir: str):
data = load_json(os.path.join(flat_dir, 'shape.json'))
flat_info = defaultdict(lambda: defaultdict(list))
for k, shape in data.items():
matched = re.match(r'decoder.layers.\d+', k)
if matched is None:
flat_key = 'flat_param_0'
else:
flat_key = f'{matched[0]}.flat_param_0'
flat_info[flat_key]['names'].append(k)
flat_info[flat_key]['shapes'].append(shape)
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
return flat_info
def convert(flat_dir: str, output_dir: str, part: int):
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
flat_sd = torch.load(flat_path)
print(f'Loaded flat state dict from {flat_path}')
output_sd = {}
for flat_key, param_meta in flat_meta.items():
flat_param = flat_sd['model'][flat_key]
assert sum(param_meta['numels']) == flat_param.numel(
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
output_sd[name] = param.view(shape)
torch.save(output_sd, output_path)
print(f'Saved unflat state dict to {output_path}')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('flat_dir')
parser.add_argument('output_dir')
parser.add_argument('part', type=int)
args = parser.parse_args()
convert(args.flat_dir, args.output_dir, args.part)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,7 @@
#!/usr/bin/env sh
for i in $(seq 0 7); do
python convert_ckpt.py $1 $2 ${i} &
done
wait $(jobs -p)

View File

@ -0,0 +1,55 @@
import os
import torch
from multiprocessing import Pool
# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main
# you can use whether wget or git lfs
path = "/path/to/your/ckpt"
new_path = "/path/to/the/processed/ckpt/"
assert os.path.isdir(path)
files = []
for filename in os.listdir(path):
filepath = os.path.join(path, filename)
if os.path.isfile(filepath):
files.append(filepath)
with Pool(14) as pool:
ckpts = pool.map(torch.load, files)
restored = {}
for ckpt in ckpts:
for k,v in ckpt.items():
if(k[0] == 'm'):
k = k[6:]
if(k == "lm_head.weight"):
k = "head.dense.weight"
if(k == "decoder.final_layer_norm.weight"):
k = "decoder.layer_norm.weight"
if(k == "decoder.final_layer_norm.bias"):
k = "decoder.layer_norm.bias"
restored[k] = v
restored["decoder.version"] = "0.0"
split_num = len(restored.keys()) // 60
count = 0
file_count = 1
tmp = {}
for k,v in restored.items():
print(k)
tmp[k] = v
count = count + 1
if(count == split_num):
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))
file_count = file_count + 1
count = 0
tmp = {}
filename = str(file_count) + "-restored.pt"
torch.save(tmp, os.path.join(new_path, filename))

View File

@ -0,0 +1,53 @@
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Train OPT model with Colossal-AI
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
## Quick Start
You can launch training by using the following bash script
```bash
bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
```
- batch-size-per-gpu: number of samples fed to each GPU, default is 16
- mem-cap: limit memory usage within a value in GB, default is 0 (no limit)
- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request
the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
- gpu-num: the number of GPUs to use, default is 1.
## Remarkable Performance
On a single GPU, Colossal-AIs automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT.png" width=1000/>
</p>
Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI!
More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d),
and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon.

View File

@ -0,0 +1,21 @@
export BS=16
export MEMCAP=0
export MODEL="6.7b"
export GPUNUM=1
for MODEL in "6.7b" "13b" "1.3b"
do
for GPUNUM in 8 1
do
for BS in 16 24 32 8
do
for MEMCAP in 0 40
do
pkill -9 torchrun
pkill -9 python
bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM
done
done
done
done

View File

@ -0,0 +1,6 @@
from colossalai.zero.shard_utils import TensorShardStrategy
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
tensor_placement_policy="auto",
reuse_fp16_shard=True),
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384))

View File

@ -0,0 +1,32 @@
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
class barrier_context():
"""
This context manager is used to allow one process to execute while blocking all
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
parallel_mode (ParallelMode): the parallel mode corresponding to a process group
Usage:
with barrier_context():
dataset = CIFAR10(root='./data', download=True)
"""
def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
# the class name is lowercase by convention
current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
self.should_block = current_rank != executor_rank
self.group = gpc.get_group(parallel_mode=parallel_mode)
def __enter__(self):
if self.should_block:
dist.barrier(group=self.group)
def __exit__(self, exc_type, exc_value, exc_traceback):
if not self.should_block:
dist.barrier(group=self.group)

View File

@ -0,0 +1,6 @@
colossalai
torch >= 1.8.1
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
accelerate == 0.13.2

View File

@ -0,0 +1,596 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
on a text file or a dataset without using HuggingFace Trainer.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import math
import os
import time
from itertools import chain
import datasets
import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AutoConfig,
AutoTokenizer,
GPT2Tokenizer,
OPTForCausalLM,
SchedulerType,
default_data_collator,
get_scheduler,
)
from transformers.utils.versions import require_version
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
def get_time_stamp():
torch.cuda.synchronize()
return time.time()
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help="The name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument("--train_file",
type=str,
default=None,
help="A csv or a json file containing the training data.")
parser.add_argument("--validation_file",
type=str,
default=None,
help="A csv or a json file containing the validation data.")
parser.add_argument(
"--validation_split_percentage",
default=5,
help="The percentage of the train set used as validation set in case there's no validation split",
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
)
parser.add_argument(
"--config_name",
type=str,
default=None,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--use_slow_tokenizer",
action="store_true",
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument("--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--model_type",
type=str,
default=None,
help="Model type to use if training from scratch.",
choices=MODEL_TYPES,
)
parser.add_argument(
"--block_size",
type=int,
default=None,
help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of"
" this size for training. Default to the model max input length for single sentence inputs (take into"
" account special tokens)."),
)
parser.add_argument(
"--preprocessing_num_workers",
type=int,
default=None,
help="The number of processes to use for the preprocessing.",
)
parser.add_argument("--overwrite_cache",
type=bool,
default=False,
help="Overwrite the cached training and evaluation sets")
parser.add_argument("--no_keep_linebreaks",
action="store_true",
help="Do not keep line breaks when using TXT files.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_model_id",
type=str,
help="The name of the repository to keep in sync with the local `output_dir`.")
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--checkpointing_steps",
type=str,
default=None,
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."),
)
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if args.train_file is not None:
extension = args.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
if args.validation_file is not None:
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
return args
def colo_memory_cap(size_in_GB):
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
cuda_capacity = colo_device_memory_capacity(get_current_device())
if size_in_GB * (1024**3) < cuda_capacity:
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
print("Using {} GB of GPU memory".format(size_in_GB))
def main():
args = parse_args()
disable_existing_loggers()
colossalai.launch_from_torch(config=dict())
logger = get_dist_logger()
is_main_process = dist.get_rank() == 0
if is_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")
# Handle the repository creation
with barrier_context():
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
logger.info("Start preparing dataset", ranks=[0])
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[:{args.validation_split_percentage}%]",
)
raw_datasets["train"] = load_dataset(
args.dataset_name,
args.dataset_config_name,
split=f"train[{args.validation_split_percentage}%:]",
)
else:
data_files = {}
dataset_args = {}
if args.train_file is not None:
data_files["train"] = args.train_file
if args.validation_file is not None:
data_files["validation"] = args.validation_file
extension = args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{args.validation_split_percentage}%]",
**dataset_args,
)
raw_datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{args.validation_split_percentage}%:]",
**dataset_args,
)
logger.info("Dataset is prepared", ranks=[0])
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
#
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
if args.config_name:
config = AutoConfig.from_pretrained(args.config_name)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(args.model_name_or_path)
else:
config = CONFIG_MAPPING[args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
logger.info("Model config has been created", ranks=[0])
if args.model_name_or_path == 'facebook/opt-13b':
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
else:
print(f'load model from {args.model_name_or_path}')
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0])
if args.init_in_cpu:
init_dev = torch.device('cpu')
else:
init_dev = get_current_device()
# build model
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
# currently, there has a bug in pretrained opt-13b
# we can not import it until huggingface fix it
logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev):
model = OPTForCausalLM(config)
else:
logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
local_files_only=False)
# enable graident checkpointing
model.gradient_checkpointing_enable()
PLACEMENT_POLICY = 'auto'
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
model = ZeroDDP(model, gemini_manager)
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
# Preprocessing the datasets.
# First we tokenize all the texts.
column_names = raw_datasets["train"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
def tokenize_function(examples):
return tokenizer(examples[text_column_name])
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
tokenized_datasets = raw_datasets.map(
tokenize_function,
batched=True,
num_proc=args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not args.overwrite_cache,
desc="Running tokenizer on dataset",
)
if args.block_size is None:
block_size = tokenizer.model_max_length
if block_size > 1024:
logger.warning(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
"Picking 1024 instead. You can change that default value by passing --block_size xxx.")
block_size = 1024
else:
if args.block_size > tokenizer.model_max_length:
logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.")
block_size = min(args.block_size, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i:i + block_size] for i in range(0, total_length, block_size)
] for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=args.preprocessing_num_workers,
load_from_cache_file=not args.overwrite_cache,
desc=f"Grouping texts in chunks of {block_size}",
)
train_dataset = lm_datasets["train"]
eval_dataset = lm_datasets["validation"]
# Log a few random samples from the training set:
# for index in random.sample(range(len(train_dataset)), 3):
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# DataLoaders creation:
train_dataloader = get_dataloader(train_dataset,
shuffle=True,
add_sampler=True,
collate_fn=default_data_collator,
batch_size=args.per_device_train_batch_size)
eval_dataloader = DataLoader(eval_dataset,
collate_fn=default_data_collator,
batch_size=args.per_device_eval_batch_size)
logger.info("Dataloaders have been created", ranks=[0])
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
logger.info("***** Running training *****", ranks=[0])
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
completed_steps = 0
starting_epoch = 0
global_step = 0
for epoch in range(starting_epoch, args.num_train_epochs):
if completed_steps >= args.max_train_steps:
break
model.train()
for step, batch in enumerate(train_dataloader):
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs['loss']
optimizer.backward(loss)
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
completed_steps += 1
global_step += 1
logger.info("Global step {} finished".format(global_step + 1), ranks=[0])
if completed_steps >= args.max_train_steps:
break
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
loss = outputs['loss'].unsqueeze(0)
losses.append(loss)
losses = torch.cat(losses)
losses = losses[:len(eval_dataset)]
try:
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
if args.output_dir is not None:
model_state = model.state_dict()
if is_main_process:
torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
dist.barrier()
# load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
# model.load_state_dict(load_state, strict=False)
logger.info("Training finished", ranks=[0])
if __name__ == "__main__":
main()

View File

@ -0,0 +1,22 @@
set -x
export BS=${1:-16}
export MEMCAP=${2:-0}
export MODEL=${3:-"125m"}
export GPUNUM=${4:-1}
# make directory for logs
mkdir -p ./logs
export MODLE_PATH="facebook/opt-${MODEL}"
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
torchrun \
--nproc_per_node ${GPUNUM} \
--master_port 19198 \
run_clm.py \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--output_dir $PWD \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODLE_PATH} \
--per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log

View File

@ -0,0 +1,16 @@
## Overview
This example shows how to use ColossalAI to run huggingface GPT training with Gemini and ZeRO DDP.
## GPT
We use the huggingface transformers GPT2 model. The input data is randonly generated.
## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
## Quick Start
You can launch training by using the following bash script
```bash
pip install -r requirements.txt
bash run.sh
```

View File

@ -0,0 +1,3 @@
colossalai >= 0.1.10
torch >= 1.8.1
transformers >= 4.231

View File

@ -0,0 +1 @@
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log

View File

@ -0,0 +1,241 @@
from functools import partial
from time import time
import psutil
import torch
import torch.nn as nn
from packaging import version
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini.",
)
args = parser.parse_args()
return args
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
if param.process_group.tp_world_size() == 1:
param.set_process_group(pg)
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
## Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def gpt2_medium(checkpoint=False):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_xl(checkpoint=True):
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
def gpt2_10b(checkpoint=True):
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
# set process group for all parameters
param.set_process_group(pg)
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # colmn slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # colmn slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # colmn slice
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placememt_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
return model
def main():
args = parse_args()
BATCH_SIZE = 8
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
disable_existing_loggers()
colossalai.launch_from_torch(config={})
pg = ProcessGroup(tp_degree=args.tp_degree)
logger = get_dist_logger()
logger.info(get_mem_info(), ranks=[0])
# build GPT model
with ColoInitContext(device=get_current_device()):
model = gpt2_medium(checkpoint=True)
numel = sum([p.numel() for p in model.parameters()])
logger.info(f'Model numel: {numel}', ranks=[0])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, args.placement)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
# build criterion
criterion = GPTLMLoss()
# build optimizer
optimizer = HybridAdam(model.parameters(), lr=1e-3)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
torch.cuda.synchronize()
model.train()
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
optimizer.backward(loss)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
optimizer.step()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])
torch.cuda.synchronize()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,143 @@
# Handson 2: Sequence Parallelism with BERT
In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length.
Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
## How to Prepare WikiPedia Dataset
First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items:
1. raw WikiPedia dataset
2. wikipedia extractor (extract data from the raw dataset)
3. vocabulary file
4. preprocessing scripts (generate final data from extracted data)
For the preprocessing script, we thank Megatron-LM for providing a preprocessing script to generate the corpus file.
```python
# download raw data
mkdir data && cd ./data
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
# install wiki extractor
git clone https://github.com/FrankLeeeee/wikiextractor.git
pip install ./wikiextractor
# extractmodule
wikiextractor --json enwiki-latest-pages-articles.xml.bz2
cat text/*/* > ./corpus.json
cd ..
# download vocab file
mkdir vocab && cd ./vocab
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
cd ..
# preprocess some data
git clone https://github.com/NVIDIA/Megatron-LM.git
cd ./Megatron-LM
python tools/preprocess_data.py \
--input ../data/corpus.json \
--output-prefix my-bert \
--vocab ../vocab/bert-large-uncased-vocab.txt \
--dataset-impl mmap \
--tokenizer-type BertWordPieceLowerCase \
--split-sentences \
--workers 24
```
After running the preprocessing scripts, you will obtain two files:
1. my-bert_text_sentence.bin
2. my-bert_text_sentence.idx
If you happen to encouter `index out of range` problem when running Megatron's script,
this is probably because that a sentence starts with a punctuation and cannot be tokenized. A work-around is to update `Encoder.encode` method with the code below:
```python
class Encoder(object):
def __init__(self, args):
...
def initializer(self):
...
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
for key in self.args.json_keys:
text = data[key]
doc_ids = []
# lsg: avoid sentences which start with a punctuation
# as it cannot be tokenized by splitter
if len(text) > 0 and text[0] in string.punctuation:
text = text[1:]
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
ids[key] = doc_ids
return ids, len(json_line)
```
## How to Train with Sequence Parallelism
We provided `train.py` for you to execute training. Before invoking the script, there are several
steps to perform.
### Step 1. Set data path and vocab path
At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`.
```python
DATA_PATH = <data-path>
VOCAB_FILE_PATH = <vocab-path>
```
`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to `DATA_PATH` to the path to the bin file without the file extension.
For example, if your my-bert_text_sentence.bin is /home/Megatron-LM/my-bert_text_sentence.bin, then you should set
```python
DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence'
```
The `VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepare the dataset
(e.g. bert-large-uncased-vocab.txt).
### Step 3. Make Dataset Helper
Build BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`.
```python
cd ./data/datasets
make
```
### Step 3. Configure your parameters
In the `config.py` provided, a set of parameters are defined including training scheme, model, etc.
You can also modify the ColossalAI setting. For example, if you wish to parallelize over the
sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=<num_of_pipeline_stages>`.
### Step 4. Invoke parallel training
Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your
machine setting.
- If you are using a single machine with multiple GPUs, PyTorch launch utility can easily let you
start your script. A sample command is like below:
```bash
python -m torch.distributed.launch --nproc_per_node <num_gpus_on_this_machine> --master_addr localhost --master_port 29500 train.py
```
- If you are using multiple machines with multiple GPUs, we suggest that you refer to `colossalai
launch_from_slurm` or `colossalai.launch_from_openmpi` as it is easier to use SLURM and OpenMPI
to start multiple processes over multiple nodes. If you have your own launcher, you can fall back
to the default `colossalai.launch` function.

View File

@ -0,0 +1,40 @@
from colossalai.amp import AMP_TYPE
DATA_PATH = ''
VOCAB_FILE_PATH = ''
# hyper-parameters
TRAIN_ITERS = 1000000
DECAY_ITERS = 990000
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
EVAL_ITERS = 10
EVAL_INTERVAL = 10
LR = 0.0001
MIN_LR = 1e-05
WEIGHT_DECAY = 0.01
SEQ_LENGTH = 512
# BERT config
DEPTH = 12
NUM_ATTENTION_HEADS = 12
HIDDEN_SIZE = 768
# model config
ADD_BINARY_HEAD = False
# random seed
SEED = 1234
# pipeline config
# only enabled when pipeline > 1
NUM_MICRO_BATCHES = 4
# colossalai config
parallel = dict(pipeline=1, tensor=dict(size=4, mode='sequence'))
fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True)
clip_grad_norm = 1.0
gradient_handler = [dict(type='SequenceParallelGradientHandler')]

View File

@ -0,0 +1,102 @@
from colossalai.context.parallel_context import ParallelContext
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.context import ParallelMode
from .datasets.data_samplers import build_pretraining_data_loader
from .datasets.builder import build_train_valid_test_datasets
import torch
def cyclic_iter(iter):
while True:
for x in iter:
yield x
def build_train_valid_test_data_iterators(train_iters,
global_batch_size,
eval_interval,
eval_iters,
dataloader_type='single',
**kwargs
):
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
logger = get_dist_logger()
logger.info('> building train, validation, and test datasets ...', ranks=[0])
# Backward compatibility, assume fixed batch size.
# if iteration > 0 and consumed_train_samples == 0:
# assert train_samples is None, \
# 'only backward compatibility support for iteration-based training'
# consumed_train_samples = iteration * global_batch_size
# if iteration > 0 and consumed_valid_samples == 0:
# if train_samples is None:
# consumed_valid_samples = (iteration // eval_interval) * \
# eval_iters * global_batch_size
# Data loader only on rank 0 of each model parallel group.
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# Number of train/valid/test samples.
train_samples = train_iters * global_batch_size
eval_iters_ = (train_iters // eval_interval + 1) * eval_iters
test_iters = eval_iters
train_val_test_num_samples = [train_samples,
eval_iters_ * global_batch_size,
test_iters * global_batch_size]
logger.info(' > datasets target sizes (minimum size):')
logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0])
logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0])
logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0])
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
train_valid_test_num_samples=train_val_test_num_samples, **kwargs)
# Build dataloaders.
dp_size = gpc.get_world_size(ParallelMode.DATA)
train_dataloader = build_pretraining_data_loader(
train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size)
valid_dataloader = build_pretraining_data_loader(
valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size)
test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size)
# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and train_iters > 0
do_valid = valid_dataloader is not None and eval_iters > 0
do_test = test_dataloader is not None and eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
flags = torch.cuda.LongTensor(
[int(do_train), int(do_valid), int(do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(flags,
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
# Build iterators.
dl_type = dataloader_type
assert dl_type in ['single', 'cyclic']
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(train_dataloader))
else:
train_data_iterator = None
if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(valid_dataloader))
else:
valid_data_iterator = None
if test_dataloader is not None:
test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(test_dataloader))
else:
test_data_iterator = None
return train_data_iterator, valid_data_iterator, test_data_iterator

View File

@ -0,0 +1,165 @@
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
import torch
_MAX_DATA_DIM = 5
def _build_key_size_numel_dictionaries(keys, data):
"""Build the size on rank 0 and broadcast."""
max_dim = _MAX_DATA_DIM
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
size = data[key].size()
for i, s in enumerate(size):
sizes[i + offset] = s
offset += max_dim
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
key_size = {}
key_numel = {}
total_numel = 0
offset = 0
for key in keys:
i = 0
size = []
numel = 1
while sizes_cpu[offset + i] > 0:
this_size = sizes_cpu[offset + i]
size.append(this_size)
numel *= this_size
i += 1
key_size[key] = size
key_numel[key] = numel
total_numel += numel
offset += max_dim
return key_size, key_numel, total_numel
def broadcast_data(keys, data, datatype):
"""Broadcast data from rank zero of each model parallel group to the
members of the same model parallel group.
Arguments:
keys: list of keys in the data dictionary to be broadcasted
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
with keys.
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
data)
# Pack on rank zero.
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# Check that all keys have the same data type.
# Flatten the data associated with the keys
flatten_data = torch.cat(
[data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
flatten_data = torch.empty(total_numel,
device=torch.cuda.current_device(),
dtype=datatype)
# Broadcast
torch.distributed.broadcast(flatten_data,
gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
# Unpack
output = {}
offset = 0
for key in keys:
size = key_size[key]
numel = key_numel[key]
output[key] = flatten_data.narrow(0, offset, numel).view(size)
offset += numel
return output
def get_batch(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def get_batch_for_sequence_parallel(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
# unpack
data_b = broadcast_data(keys, data, datatype)
# # get tensor parallel local rank
global_rank = torch.distributed.get_rank()
local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR)
local_rank = global_rank % local_world_size
seq_length = data_b['text'].size(1)
sub_seq_length = seq_length // local_world_size
sub_seq_start = local_rank * sub_seq_length
sub_seq_end = (local_rank+1) * sub_seq_length
#
# # Unpack.
tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long()
types = data_b['types'][:, sub_seq_start:sub_seq_end].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float()
lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long()
padding_mask = data_b['padding_mask'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
class SequenceParallelDataIterator:
def __init__(self, data_iter):
self.data_iter = data_iter
def __iter__(self):
return self.data_iter
def __next__(self):
return get_batch_for_sequence_parallel(self.data_iter)

View File

@ -0,0 +1,9 @@
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@

View File

@ -0,0 +1 @@
from . import indexed_dataset

View File

@ -0,0 +1,225 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Style dataset."""
from colossalai.logging import get_dist_logger
import numpy as np
import torch
from torch.utils.data import Dataset
from ..tokenizer import get_tokenizer
from .dataset_utils import (get_a_and_b_segments, truncate_segments, create_tokens_and_tokentypes,
create_masked_lm_predictions, pad_and_convert_to_numpy)
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
import time
import os
from . import helpers
class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
short_seq_prob, seed, binary_head):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping_(
self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 3, # account for added tokens,
short_seq_prob,
self.seed,
self.name,
self.binary_head)
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
# We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
return build_training_sample(
sample,
seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng,
self.binary_head)
def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob,
seed, name, binary_head):
logger = get_dist_logger()
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0])
# First compile and then import.
samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs,
max_num_samples, max_seq_length, short_seq_prob, seed, verbose,
2 if binary_head else 1)
logger.info('\n > done building samples index maping', ranks=[0])
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0])
# Make sure all the ranks have built the mapping
logger.info('\n > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(time.time() - start_time),
ranks=[0])
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA))
if gpc.is_initialized(ParallelMode.PIPELINE):
torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE))
assert counts[0].item() == (torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)))
# Load indexed dataset.
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) +
'\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) +
'\n total number of samples: {}'.format(samples_mapping.shape[0]),
ranks=[0])
return samples_mapping
def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id,
sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head):
"""Build training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number genenrator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the opper bound whereas the numpy one is exclusive.
"""
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b), max_num_tokens, np_rng)
# Build tokens and toketypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels,
_) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id,
mask_id, max_predictions_per_seq, np_rng)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)
}
return train_sample

View File

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blendable dataset."""
import time
import numpy as np
import torch
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = 0
for dataset in self.datasets:
self.size += len(dataset)
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
# Build indices.
start_time = time.time()
assert num_datasets < 255
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
from . import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
def __len__(self):
return self.size
def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
return self.datasets[dataset_idx][sample_idx]

View File

@ -0,0 +1,152 @@
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_
from .bert_dataset import BertDataset
from colossalai.logging import get_dist_logger
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
# Get start and end indices of train/valid/train into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
logger = get_dist_logger()
# Print stats about the splits.
logger.info('\n > dataset split:', ranks=[0])
def print_split_stats(name, index):
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
logger.info('\n {}:'.format(name) +
'\n document indices in [{}, {}) total of {} documents'.format(
splits[index], splits[index + 1],
splits[index + 1] - splits[index]) +
'\n sentence indices in [{}, {}) total of {} sentences'.format(
start_index, end_index,
end_index - start_index),
ranks=[0])
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
seed=seed,
)
if dataset_type != DSET_TYPE_BERT:
raise NotImplementedError("Only BERT dataset is supported")
else:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
binary_head=binary_head,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
binary_head,
dataset_type=dataset_type)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)

View File

@ -0,0 +1,153 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloaders."""
import torch
import random
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0):
"""Build dataloader given an input dataset."""
if dataset is None:
return None
# Megatron sampler
if dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
elif dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA),
data_parallel_size=gpc.get_world_size(ParallelMode.DATA))
else:
raise Exception('{} dataloader type is not supported.'.format(dataloader_type))
# Torch dataloader.
return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True)
class MegatronPretrainingSampler:
def __init__(self,
total_samples,
consumed_samples,
micro_batch_size,
data_parallel_rank,
data_parallel_size,
drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []

View File

@ -0,0 +1,592 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import math
import time
import collections
from colossalai.logging import get_dist_logger
import numpy as np
from .blendable_dataset import BlendableDataset
from .indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples):
# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0]*num_datasets
prefixes = [0]*num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2*i])
prefixes[i] = (data_prefix[2*i+1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]
# Add 0.5% (the 1.005 factor) so in case the bleding dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005))
for val in train_valid_test_num_samples])
return prefixes, weights, datasets_train_valid_test_num_samples
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
import os
import subprocess
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(['make', '-C', path])
if ret.returncode != 0:
print("Making C++ dataset helpers module failed, exiting.")
import sys
sys.exit(1)
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
# Number of sentences in the sample.
n_sentences = len(sample)
# Make sure we always have two sentences.
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
# First part:
# `a_end` is how many sentences go into the `A`.
a_end = 1
if n_sentences >= 3:
# Note that randin in numpy is exclusive.
a_end = np_rng.randint(1, n_sentences)
tokens_a = []
for j in range(a_end):
tokens_a.extend(sample[j])
# Second part:
tokens_b = []
for j in range(a_end, n_sentences):
tokens_b.extend(sample[j])
# Random next:
is_next_random = False
if np_rng.random() < 0.5:
is_next_random = True
tokens_a, tokens_b = tokens_b, tokens_a
return tokens_a, tokens_b, is_next_random
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens)
assert len_a > 0
if len_a + len_b <= max_num_tokens:
return False
while len_a + len_b > max_num_tokens:
if len_a > len_b:
len_a -= 1
tokens = tokens_a
else:
len_b -= 1
tokens = tokens_b
if np_rng.random() < 0.5:
del tokens[0]
else:
tokens.pop()
return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
tokens = []
tokentypes = []
# [CLS].
tokens.append(cls_id)
tokentypes.append(0)
# Segment A.
for token in tokens_a:
tokens.append(token)
tokentypes.append(0)
# [SEP].
tokens.append(sep_id)
tokentypes.append(0)
# Segment B.
for token in tokens_b:
tokens.append(token)
tokentypes.append(1)
if tokens_b:
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def is_start_piece(piece):
"""Check if the current word piece is the starting piece (BERT)."""
# When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
return not piece.startswith("##")
def create_masked_lm_predictions(tokens,
vocab_id_list, vocab_id_to_token_dict,
masked_lm_prob,
cls_id, sep_id, mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
do_permutation=False):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = []
# Note(mingdachen): We create a list for recording if the piece is
# the starting piece of current token, where 1 means true, so that
# on-the-fly whole word masking is possible.
token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens):
if token == cls_id or token == sep_id:
token_boundary[i] = 1
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
not is_start_piece(vocab_id_to_token_dict[token])):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
if is_start_piece(vocab_id_to_token_dict[token]):
token_boundary[i] = 1
output_tokens = list(tokens)
masked_lm_positions = []
masked_lm_labels = []
if masked_lm_prob == 0:
return (output_tokens, masked_lm_positions,
masked_lm_labels, token_boundary)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
# Note(mingdachen):
# By default, we set the probabilities to favor shorter ngram sequences.
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
pvals = 1. / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
if favor_longer_ngram:
pvals = pvals[::-1]
ngram_indexes = []
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx:idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
masked_lms = []
covered_indexes = set()
for cand_index_set in ngram_indexes:
if len(masked_lms) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes:
continue
n = np_rng.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# Note(mingdachen):
# Repeatedly looking for a candidate that does not exceed the
# maximum number of predictions by trying shorter ngrams.
while len(masked_lms) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if np_rng.random() < 0.8:
masked_token = mask_id
else:
# 10% of the time, keep original
if np_rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
np_rng.shuffle(ngram_indexes)
select_indexes = set()
if do_permutation:
for cand_index_set in ngram_indexes:
if len(select_indexes) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes or index in select_indexes:
continue
n = np.random.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
index_set = sum(cand_index_set[n - 1], [])
n -= 1
while len(select_indexes) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(select_indexes) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes or index in select_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
select_indexes.add(index)
assert len(select_indexes) <= num_to_predict
select_indexes = sorted(select_indexes)
permute_indexes = list(select_indexes)
np_rng.shuffle(permute_indexes)
orig_token = list(output_tokens)
for src_i, tgt_i in zip(select_indexes, permute_indexes):
output_tokens[src_i] = orig_token[tgt_i]
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
# Lables and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
binary_head,
dataset_type=dataset_type)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
logger = get_dist_logger()
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
skip_warmup)
# Get start and end indices of train/valid/train into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
logger.info('\n > dataset split:')
def print_split_stats(name, index):
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
logger.info('\n {}:'.format(name) +
'\n document indices in [{}, {}) total of {} documents'.format(
splits[index],
splits[index + 1],
splits[index + 1] - splits[index]) +
'\n sentence indices in [{}, {}) total of {} sentences'.format(
start_index,
end_index,
end_index - start_index),
ranks=[0])
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from .bert_dataset import BertDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
seed=seed,
binary_head=binary_head
)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.use_one_sent_docs,
**kwargs
)
else:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
logger = get_dist_logger()
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
logger.info('\n > building dataset index ...', ranks=[0])
logger.info('\n > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time), ranks=[0])
logger.info('\n > indexed dataset stats:' +
'\n number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1) +
'\n number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]),
ranks=[0]
)
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index

View File

@ -0,0 +1,717 @@
/*
coding=utf-8
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/* Helper methods for fast index mapping builds */
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
namespace py = pybind11;
using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
py::array_t<int64_t>& dataset_sample_index,
const py::array_t<double>& weights,
const int32_t num_datasets,
const int64_t size, const bool verbose) {
/* Given multiple datasets and a weighting array, build samples
such that it follows those wieghts.*/
if (verbose) {
std::cout << "> building indices for blendable datasets ..." << std::endl;
}
// Get the pointer access without the checks.
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto weights_ptr = weights.unchecked<1>();
// Initialize buffer for number of samples used for each dataset.
int64_t current_samples[num_datasets];
for(int64_t i = 0; i < num_datasets; ++i) {
current_samples[i] = 0;
}
// For each sample:
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
// Determine where the max error in sampling is happening.
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
int64_t max_error_index = 0;
double max_error = weights_ptr[0] * sample_idx_double -
static_cast<double>(current_samples[0]);
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
double error = weights_ptr[dataset_idx] * sample_idx_double -
static_cast<double>(current_samples[dataset_idx]);
if (error > max_error) {
max_error = error;
max_error_index = dataset_idx;
}
}
// Populate the indices.
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
// Update the total samples.
current_samples[max_error_index] += 1;
}
// print info
if (verbose) {
std::cout << " > sample ratios:" << std::endl;
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
static_cast<double>(size);
std::cout << " dataset " << dataset_idx << ", input: " <<
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
}
}
}
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch) {
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
// Mapping and it's length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Begining offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the begining of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_) {
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(int32_t);
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
{2*byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen) {
/* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1);
}
return max_length;
}
template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const double short_seq_prob,
const int32_t seed,
const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int.
int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " short sequence probability: " << short_seq_prob <<
endl << std::flush;
cout << " short sequence ration (1/prob): " << short_seq_ratio <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and it's length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the seed so both iterations produce the same results.
std::mt19937 rand32_gen(seed);
// Set the flag on second iteration.
second = (iteration == 1);
// Counters:
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// Current map index.
uint64_t map_index = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
// At the begining of the document previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining documents.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent > 1) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have more than two sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
auto target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length.
// and if not only one sentence is left in the document.
// and if we have at least two sentneces.
// and if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow.
if ((3 * map_index + 2) >
std::numeric_limits<int64_t>::max()) {
cout << "number of samples exceeded maximum "
<< "allowed by type int64: "
<< std::numeric_limits<int64_t>::max()
<< endl;
throw std::overflow_error("Number of samples");
}
// Populate the map.
if (second) {
const auto map_index_0 = 3 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
}
// Update indices / counters.
++map_index;
prev_start_index = sent_index + 1;
target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const double short_seq_prob,
const int seed,
const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
}
}
template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& titles_sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
auto titles_sizes = titles_sizes_.unchecked<1>();
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the flag on second iteration.
second = (iteration == 1);
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
const auto target_seq_len = max_seq_length - titles_sizes[doc];
// At the begining of the document previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining documents.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length.
// and there are an acceptable number of sentences left
// and if we have at least the minimum number of sentences.
// or if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map.
if (second) {
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const py::array_t<int>& titles_sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_blending_indices", &build_blending_indices);
}

View File

@ -0,0 +1,156 @@
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
rather than for training, since it is only built with a single epoch sample mapping.
"""
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
kwargs = dict(
name='full',
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=1,
max_num_samples=None,
max_seq_length=args.seq_length,
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles,
use_one_sent_docs=args.use_one_sent_docs
)
dataset = ICTDataset(**kwargs)
return dataset
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return len(self.samples_mapping)
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
sample_data = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
}
return sample
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return np.array(tokens), np.array(pad_mask)

View File

@ -0,0 +1,569 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate
import numpy as np
import torch
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
else:
return np.int32
def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
else:
return None
else:
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
def make_builder(out_file, impl, vocab_size=None):
if impl == 'mmap':
return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
else:
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == 'mmap':
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
8: np.uint16
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i + 1)
return doc_idx
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
def __init__(self, path):
super().__init__()
self.path = path
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
def __del__(self):
if self.data_file:
self.data_file.close()
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return a
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
self.data_file.readinto(a)
offsets = list(accumulate(sizes))
sents = np.split(a, offsets[:-1])
return sents
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx: ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx: ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work, can optimizer later if necessary
sents = []
for i in range(*idx.indices(len(self))):
sents.append(self[i])
return sents
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
self.doc_idx = [0]
def add_item(self, tensor):
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def end_document(self):
self.doc_idx.append(len(self.sizes))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]
for offset in index.data_offsets[1:]:
self.data_offsets.append(begin + offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
index.write(struct.pack('<Q', len(self.doc_idx)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
write_longs(index, self.doc_idx)
index.close()
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
while stream.read(100 * 1024 * 1024):
pass
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
self._file.write(cls._HDR_MAGIC)
self._file.write(struct.pack('<Q', 1))
self._file.write(struct.pack('<B', code(dtype)))
return self
@staticmethod
def _get_pointers(sizes):
dtype_size = dtype().itemsize
address = 0
pointers = []
for size in sizes:
pointers.append(address)
address += size * dtype_size
return pointers
def write(self, sizes, doc_idx):
pointers = self._get_pointers(sizes)
self._file.write(struct.pack('<Q', len(sizes)))
self._file.write(struct.pack('<Q', len(doc_idx)))
sizes = np.array(sizes, dtype=np.int32)
self._file.write(sizes.tobytes(order='C'))
del sizes
pointers = np.array(pointers, dtype=np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
print(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
print(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
@property
def sizes(self):
return self._index.sizes
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
for size in index.sizes:
self._sizes.append(size)
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)

View File

@ -0,0 +1,125 @@
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys
import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
def test_indexed_dataset(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
print(len(ds.doc_idx))
print(len(ds))
print(ds.doc_idx[-1])
if ds.supports_prefetch:
# just prefetch the whole thing in test (so assume it is small)
ds.prefetch(range(len(ds)))
if args.count > len(ds.doc_idx) - 1:
args.count = len(ds.doc_idx) - 1
for i in range(args.count):
start = ds.doc_idx[i]
end = ds.doc_idx[i + 1]
ids = ds[start:end]
print(f"Document {i}:")
print("--------------")
for s in ids:
assert len(s) > 0
l = s.data.tolist()
text = tokenizer.detokenize(l)
print(text)
print("---")
def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
size = ds.sizes[0]
print(f"size: {size}")
full = ds.get(0)
print(full)
# print(tokenizer.detokenize(full.data.tolist()))
print("---")
end = ds.get(0, offset=size - 10)
print(end)
# print(tokenizer.detokenize(end.data.tolist()))
start = ds.get(0, length=10)
print(start)
# print(tokenizer.detokenize(start.data.tolist()))
part = ds.get(0, offset=2, length=8)
print(part)
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
# # ds = AlbertDataset(idataset, tokenizer)
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
# args.epochs, args.max_num_samples,
# args.masked_lm_prob, args.seq_length,
# args.short_seq_prob, args.seed)
# truncated = 0
# total = 0
# for i, s in enumerate(ds):
# ids = s['text']
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
# print(tokens)
# if i >= args.count-1:
# exit()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None,
help='Maximum number of samples to plan for')
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
help='probability of masking tokens')
parser.add_argument('--seq-length', type=int, default=512,
help='maximum sequence length')
parser.add_argument('--short-seq-prob', type=float, default=0.1,
help='probability of creating a short sequence')
parser.add_argument('--seed', type=int, default=1234,
help='random seed')
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
# test_albert_dataset(args)
test_indexed_dataset_get(args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,10 @@
#!/bin/bash
IMPL=cached
python ../preprocess_data.py \
--input test_samples.json \
--vocab vocab.txt \
--dataset-impl ${IMPL} \
--output-prefix test_samples_${IMPL} \
--workers 1 \
--log-interval 2

View File

@ -0,0 +1,38 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import build_tokenizer
_TOKENIZER = None
_PADDED_VOCAB_SIZE = -1
def initialize_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
tokenizer, padded_vocab_size = build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids)
global _TOKENIZER, _PADDED_VOCAB_SIZE
_TOKENIZER = tokenizer
_PADDED_VOCAB_SIZE = padded_vocab_size
def get_tokenizer():
global _TOKENIZER
return _TOKENIZER
def get_padded_vocab_size():
global _PADDED_VOCAB_SIZE
return _PADDED_VOCAB_SIZE

View File

@ -0,0 +1,431 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenization."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abbreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
def vocab_size(self):
return len(self.vocab)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False

View File

@ -0,0 +1,256 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron tokenizers."""
from abc import ABC
from abc import abstractmethod
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .bert_tokenization import FullTokenizer as FullBertTokenizer
def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0):
"""Initialize tokenizer."""
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print('> building {} tokenizer ...'.format(tokenizer_type),
flush=True)
# Select and instantiate the tokenizer.
if tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
lower_case=True,
vocab_extra_ids=vocab_extra_ids)
elif tokenizer_type == 'BertWordPieceCase':
tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file,
lower_case=False,
vocab_extra_ids=vocab_extra_ids)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(tokenizer_type))
# Add vocab size.
padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size)
return tokenizer, padded_vocab_size
def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128):
"""Pad vocab size so it is divisible by model parallel size and
still having GPU friendly size."""
after = orig_vocab_size
if gpc.is_initialized(ParallelMode.TENSOR):
multiple = make_vocab_size_divisible_by * gpc.get_world_size(ParallelMode.TENSOR)
else:
multiple = make_vocab_size_divisible_by
while (after % multiple) != 0:
after += 1
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
print(' > padded vocab (size: {}) with {} dummy tokens '
'(new size: {})'.format(
orig_vocab_size, after - orig_vocab_size, after), flush=True)
return after
class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""
def __init__(self, name):
self.name = name
super().__init__()
@property
@abstractmethod
def vocab_size(self):
pass
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass
@abstractmethod
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
'tokenizer'.format(self.name))
@property
def sep(self):
raise NotImplementedError('SEP is not provided for {} '
'tokenizer'.format(self.name))
@property
def pad(self):
raise NotImplementedError('PAD is not provided for {} '
'tokenizer'.format(self.name))
@property
def eod(self):
raise NotImplementedError('EOD is not provided for {} '
'tokenizer'.format(self.name))
@property
def mask(self):
raise NotImplementedError('MASK is not provided for {} '
'tokenizer'.format(self.name))
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0):
if lower_case:
name = 'BERT Lower Case'
else:
name = 'BERT Upper Case'
super().__init__(name)
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case)
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.pad_id = self.tokenizer.vocab['[PAD]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self._additional_special_tokens = []
# (dsachan) Add BOS and EOS tokens
SPECIAL_TOKENS = {'eos_token': '[EOS]',
'bos_token': '[BOS]'}
self._bos_token = '[BOS]'
self.add_token(self._bos_token)
self._bos_token_id = self.vocab.get(self._bos_token)
self._eos_token = '[EOS]'
self.add_token(self._eos_token)
self._eos_token_id = self.vocab.get(self._eos_token)
# (dsachan) Add additional special tokens
# These can be used as sentinel tokens in T5 model inputs
additional_special_tokens = []
additional_special_tokens.extend(
["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)])
self.add_additional_special_tokens(additional_special_tokens)
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
# self.vocab_size comes from len(vocab)
# and it will increase as we add elements
self.vocab[token] = self.vocab_size
def add_additional_special_tokens(self, tokens_list):
setattr(self, "additional_special_tokens", tokens_list)
for value in tokens_list:
self.add_token(value)
@property
def vocab_size(self):
return self.tokenizer.vocab_size()
@property
def vocab(self):
return self.tokenizer.vocab
@property
def inv_vocab(self):
return self.tokenizer.inv_vocab
def tokenize(self, text):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def mask(self):
return self.mask_id
@property
def bos_token(self):
""" Beginning of sentence token id """
return self._bos_token
@property
def eos_token(self):
""" End of sentence token id """
return self._eos_token
@property
def additional_special_tokens(self):
""" All the additional special tokens you may want to use (list of strings)."""
return self._additional_special_tokens
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary."""
return self._bos_token_id
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary."""
return self._eos_token_id
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers)."""
return [self.vocab.get(token) for token in self._additional_special_tokens]
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value

View File

@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
import torch.nn.functional as F
import torch.distributed as dist
from .cross_entropy import vocab_cross_entropy
class BertLoss(nn.Module):
def forward(self,
lm_loss,
sop_logits,
loss_mask,
sentence_order):
lm_loss_ = lm_loss.float()
loss_mask = loss_mask.float()
loss_mask_sum = loss_mask.sum()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1))
lm_loss /= loss_mask_sum
torch.distributed.all_reduce(
lm_loss,
group=gpc.get_group(ParallelMode.SEQUENCE)
)
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE)
else:
sop_loss = None
loss = lm_loss
return loss

View File

@ -0,0 +1,75 @@
from colossalai.context.parallel_mode import ParallelMode
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
class _VocabCrossEntropy(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = target < 0
masked_target = target.clone()
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
def vocab_cross_entropy(vocab_logits, target):
"""helper function for the cross entropy."""
return _VocabCrossEntropy.apply(vocab_logits, target)

View File

@ -0,0 +1,55 @@
import torch
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions,
contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size)

View File

@ -0,0 +1 @@
from .annealing_lr import AnnealingLR

View File

@ -0,0 +1,158 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Learning rate decay functions."""
import math
class AnnealingLR(object):
"""Anneals the learning rate."""
def __init__(self,
optimizer,
max_lr,
min_lr,
warmup_steps,
decay_steps,
decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
# Class values.
self.optimizer = optimizer
self.max_lr = float(max_lr)
self.min_lr = min_lr
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps
self.num_steps = 0
self.decay_steps = decay_steps
assert self.decay_steps > 0
assert self.warmup_steps < self.decay_steps
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
# Set the learning rate
self.step(0)
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps:
return self.min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr
if self.decay_style == 'linear':
coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
return self.min_lr + coeff * delta_lr
def step(self, increment=1):
"""Set lr for all parameters groups."""
self.num_steps += increment
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
state_dict = {
'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps,
'num_steps': self.num_steps,
'decay_style': self.decay_style,
'decay_steps': self.decay_steps,
'min_lr': self.min_lr
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_lr_scheduler:
return cls_value
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint' \
f'value {sd_value} for {name} do not match'
return sd_value
def load_state_dict(self, sd):
if 'start_lr' in sd:
max_lr_ = sd['start_lr']
else:
max_lr_ = sd['max_lr']
self.max_lr = self._check_and_set(self.max_lr, max_lr_,
'learning rate')
self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
'minimum learning rate')
if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter']
else:
warmup_steps_ = sd['warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps,
warmup_steps_,
'warmup iterations')
if 'end_iter' in sd:
decay_steps_ = sd['end_iter']
else:
decay_steps_ = sd['decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'],
'decay style')
if 'num_iters' in sd:
num_steps = sd['num_iters']
else:
num_steps = sd['num_steps']
self.step(increment=num_steps)

View File

@ -0,0 +1,2 @@

View File

@ -0,0 +1,282 @@
from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
import inspect
from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding
from .layers.init_method import init_normal, output_init_normal
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.utils import partition_uniform
class BertForPretrain(nn.Module):
def __init__(self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers
if not add_binary_head:
num_tokentypes = 0
self.preprocessor = PreProcessor(self.sub_seq_length)
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
self.bert_layers = nn.ModuleList()
for i in range(num_layers):
bert_layer = BertLayer(layer_number=i+1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16
)
self.bert_layers.append(bert_layer)
self.layer_norm = LayerNorm(hidden_size)
self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0),
add_binary_head=add_binary_head)
self.reset_parameters()
def _init_normal(self, tensor):
init_normal(tensor, sigma=self.init_std)
def _output_init_normal(self, tensor):
output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)
def reset_parameters(self):
# initialize embedding
self._init_normal(self.embedding.word_embedding_weight)
self._init_normal(self.embedding.position_embeddings.weight)
if self.embedding.tokentype_embeddings:
self._init_normal(self.embedding.tokentype_embeddings.weight)
# initialize bert layer
for layer in self.bert_layers:
# initialize self attention
self._init_normal(layer.self_attention.query_key_value.weight)
self._output_init_normal(layer.self_attention.dense.weight)
self._init_normal(layer.mlp.dense_h_to_4h.weight)
self._output_init_normal(layer.mlp.dense_4h_to_h.weight)
# initializer head
self._init_normal(self.head.lm_head.dense.weight)
if self.head.binary_head is not None:
self._init_normal(self.head.binary_head.pooler.dense.weight)
self._init_normal(self.head.binary_head.dense.weight)
def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
# inputs of the forward function
# input_ids: [batch_size, sub_seq_len]
# attention_mask: [batch_size, seq_len]
# tokentype_ids: [batch_size, sub_seq_len]
# outputs of preprocessor
# pos_ids: [batch_size, sub_seq_len]
# attention_masks: [batch_size, 1, sub_seq_len, seq_len]
pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
# hidden_states shape change:
# [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
hidden_states = hidden_states.transpose(0, 1).contiguous()
for idx, layer in enumerate(self.bert_layers):
hidden_states = layer(hidden_states, attention_masks)
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.layer_norm(hidden_states)
# hidden_states: [sub_seq_len, batch_size, hidden_size]
# word_embedding: [vocab_size, hidden_size]
return self.head(output, self.embedding.word_embedding_weight, lm_labels)
class PipelineBertForPretrain(nn.Module):
def __init__(self,
vocab_size,
hidden_size,
max_sequence_length,
num_attention_heads,
num_layers,
add_binary_head,
is_naive_fp16,
num_tokentypes=2,
dropout_prob=0.1,
mlp_ratio=4,
init_std=0.02,
convert_fp16_to_fp32_in_softmax=False,
first_stage=True,
last_stage=True,
start_idx=None,
end_idx=None):
super().__init__()
self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
self.sub_seq_length = max_sequence_length // self.seq_parallel_size
self.init_std = init_std
self.num_layers = num_layers
if not add_binary_head:
num_tokentypes = 0
self.first_stage = first_stage
self.last_stage = last_stage
self.preprocessor = PreProcessor(self.sub_seq_length)
if self.first_stage:
self.embedding = Embedding(hidden_size=hidden_size,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=dropout_prob,
num_tokentypes=num_tokentypes)
# transformer layers
self.bert_layers = nn.ModuleList()
if start_idx is None and end_idx is None:
start_idx = 0
end_idx = num_layers
for i in range(start_idx, end_idx):
bert_layer = BertLayer(layer_number=i+1,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=dropout_prob,
mlp_ratio=mlp_ratio,
hidden_dropout=dropout_prob,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
is_naive_fp16=is_naive_fp16
)
self.bert_layers.append(bert_layer)
if self.last_stage:
self.word_embeddings = VocabEmbedding(vocab_size, hidden_size)
self.layer_norm = LayerNorm(hidden_size)
self.head = BertDualHead(hidden_size, vocab_size,
add_binary_head=add_binary_head)
self.reset_parameters()
def _init_normal(self, tensor):
init_normal(tensor, sigma=self.init_std)
def _output_init_normal(self, tensor):
output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)
def reset_parameters(self):
# initialize embedding
if self.first_stage:
self._init_normal(self.embedding.word_embedding_weight)
self._init_normal(self.embedding.position_embeddings.weight)
if self.embedding.tokentype_embeddings:
self._init_normal(self.embedding.tokentype_embeddings.weight)
# initialize bert layer
for layer in self.bert_layers:
# initialize self attention
self._init_normal(layer.self_attention.query_key_value.weight)
self._output_init_normal(layer.self_attention.dense.weight)
self._init_normal(layer.mlp.dense_h_to_4h.weight)
self._output_init_normal(layer.mlp.dense_4h_to_h.weight)
# initializer head
if self.last_stage:
self._init_normal(self.head.lm_head.dense.weight)
if self.head.binary_head is not None:
self._init_normal(self.head.binary_head.pooler.dense.weight)
self._init_normal(self.head.binary_head.dense.weight)
def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
# inputs of the forward function
# input_ids: [batch_size, sub_seq_len]
# attention_mask: [batch_size, seq_len]
# tokentype_ids: [batch_size, sub_seq_len]
# outputs of preprocessor
# pos_ids: [batch_size, sub_seq_len]
# attention_masks: [batch_size, 1, sub_seq_len, seq_len]
if self.first_stage:
pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
else:
_, attention_masks = self.preprocessor(None, attention_masks)
if self.first_stage:
hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
hidden_states = input_ids
# hidden_states shape change:
# [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
for idx, layer in enumerate(self.bert_layers):
hidden_states = layer(hidden_states, attention_masks)
if self.last_stage:
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.layer_norm(hidden_states)
output = self.head(output, self.word_embeddings.weight, lm_labels)
else:
output = hidden_states
# hidden_states: [sub_seq_len, batch_size, hidden_size]
# word_embedding: [vocab_size, hidden_size]
return output
def _filter_kwargs(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}
def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
logger = get_dist_logger()
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
rank = gpc.get_global_rank()
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs['num_layers'] = num_layers
kwargs['start_idx'] = start
kwargs['end_idx'] = end
kwargs['first_stage'] = start == 0
kwargs['last_stage'] = end == num_layers
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
if start == 0:
wrapper.register_module(chunk.embedding.word_embeddings)
elif end == num_layers:
wrapper.register_module(chunk.word_embeddings)
models.append(chunk)
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
return model

View File

@ -0,0 +1,4 @@
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .head import BertDualHead
from .preprocess import PreProcessor

View File

@ -0,0 +1,118 @@
import torch
import torch.nn as nn
from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing
from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from colossalai.kernel.cuda_native import LayerNorm
from .mlp import TransformerMLP
from .dropout import get_bias_dropout_add
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
class BertLayer(nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self,
layer_number,
hidden_size,
num_attention_heads,
attention_dropout,
mlp_ratio,
hidden_dropout,
is_naive_fp16,
apply_residual_connection_post_layernorm=False,
fp32_residual_connection=False,
bias_dropout_fusion: bool = True,
convert_fp16_to_fp32_in_softmax: bool = False):
super().__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.fp32_residual_connection = fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size)
# Self attention.
self.self_attention = TransformerSelfAttentionRing(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=attention_dropout,
attention_mask_func=attention_mask_func,
layer_number=layer_number,
apply_query_key_layer_scaling=True,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
fp16=is_naive_fp16
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(hidden_size)
self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio)
def forward(self, hidden_states, attention_mask):
# hidden_states: [batch_size, sub_seq_len, hidden_size]
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
# trigerring the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
return output

View File

@ -0,0 +1,13 @@
import torch
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add

View File

@ -0,0 +1,96 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class VocabEmbedding(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super(VocabEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Allocate weights and initialize.
self.weight = nn.Parameter(torch.empty(
self.num_embeddings, self.embedding_dim))
init.xavier_uniform_(self.weight)
def forward(self, hidden_state):
output = F.embedding(hidden_state, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
return output
def __repr__(self):
return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
f'embedding_dim={self.embedding_dim})'
class Embedding(nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
self.num_tokentypes = num_tokentypes
self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@property
def word_embedding_weight(self):
return self.word_embeddings.weight
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None and self.tokentype_embeddings is not None:
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings

View File

@ -0,0 +1,78 @@
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooler import Pooler
from .linear import Linear
from .embedding import VocabEmbedding
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from loss_func.cross_entropy import vocab_cross_entropy
class BertLMHead(nn.Module):
"""Masked LM head for Bert
Arguments:
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
"""
def __init__(self,
vocab_size,
hidden_size,
):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
self.dense = Linear(hidden_size, hidden_size)
self.layernorm = LayerNorm(hidden_size)
self.gelu = torch.nn.functional.gelu
def forward(self, hidden_states, word_embeddings_weight, lm_labels):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = F.linear(hidden_states, word_embeddings_weight, self.bias)
lm_loss = vocab_cross_entropy(output, lm_labels)
return lm_loss
class BertBinaryHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.pooler = Pooler(hidden_size)
self.dense = Linear(hidden_size, 2)
def forward(self, hidden_states):
if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0:
output = self.pooler(hidden_states)
output = self.dense(output)
else:
output = None
return output
class BertDualHead(nn.Module):
def __init__(self, hidden_size, vocab_size, add_binary_head):
super().__init__()
self.lm_head = BertLMHead(vocab_size, hidden_size)
self.add_binary_head = add_binary_head
if add_binary_head:
self.binary_head = BertBinaryHead(hidden_size)
else:
self.binary_head = None
def forward(self, hidden_states, word_embeddings_weight, lm_labels):
if self.add_binary_head:
binary_output = self.binary_head(hidden_states)
else:
binary_output = None
lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels)
return lm_loss, binary_output

View File

@ -0,0 +1,12 @@
import torch
import math
def init_normal(tensor, sigma):
"""Init method based on N(0, sigma)."""
torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
def output_init_normal(tensor, sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
torch.nn.init.normal_(tensor, mean=0.0, std=std)

View File

@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init
class Linear(nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
super(Linear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
init.normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output = F.linear(input_, self.weight, bias)
if self.skip_bias_add:
return output, self.bias
else:
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'

View File

@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl
class TransformerMLP(nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
"""
def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
super(TransformerMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = Linear(
hidden_size,
int(hidden_size*mlp_ratio),
skip_bias_add=True)
self.bias_gelu_fusion = fuse_gelu
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = Linear(
int(hidden_size*mlp_ratio),
hidden_size,
skip_bias_add=True)
def forward(self, hidden_states):
# hidden states should be in the shape of [s, b, h]
# it will be projects into [s, b, 4h]
# and projected back to [s, b, h]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias

View File

@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from .linear import Linear
class Pooler(nn.Module):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size):
super(Pooler, self).__init__()
self.dense = Linear(hidden_size, hidden_size)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled

View File

@ -0,0 +1,58 @@
from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc
class PreProcessor(nn.Module):
def __init__(self, sub_seq_length):
super().__init__()
self.sub_seq_length = sub_seq_length
def bert_position_ids(self, token_ids):
# Create position ids
seq_length = token_ids.size(1)
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
position_ids = torch.arange(seq_length*local_rank,
seq_length * (local_rank+1),
dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
def bert_extended_attention_mask(self, attention_mask):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
start_index = local_rank * self.sub_seq_length
end_index = (local_rank + 1) * self.sub_seq_length
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s/D, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
attention_mask_bss = attention_mask_bss[:, start_index:end_index, :]
# [b, 1, s/D, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Convert attention mask to binary:
extended_attention_mask = (extended_attention_mask < 0.5)
return extended_attention_mask
def forward(self, input_ids=None, attention_mask=None):
if attention_mask is not None:
extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
else:
extended_attention_mask = None
if input_ids is not None:
position_ids = self.bert_position_ids(input_ids)
else:
position_ids = None
return position_ids, extended_attention_mask

View File

@ -0,0 +1,210 @@
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from data import build_train_valid_test_data_iterators
from data.tokenizer import initialize_tokenizer, get_padded_vocab_size
from data.bert_helper import get_batch_for_sequence_parallel, SequenceParallelDataIterator
from colossalai.amp import AMP_TYPE
from colossalai.logging import get_dist_logger
from colossalai.utils import MultiTimer, is_using_pp
from model.bert import BertForPretrain
from lr_scheduler import AnnealingLR
from loss_func.bert_loss import BertLoss
import torch
from colossalai.engine.schedule import PipelineSchedule
from colossalai.amp import AMP_TYPE
from colossalai.nn.optimizer import FusedAdam
from colossalai.kernel import LayerNorm
from model.bert import build_pipeline_bert
def process_batch_data(batch_data):
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = batch_data
if gpc.is_first_rank(ParallelMode.PIPELINE):
data = dict(input_ids=tokens, attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
else:
data = dict(attention_masks=padding_mask, tokentype_ids=types, lm_labels=lm_labels)
label = dict(loss_mask=loss_mask, sentence_order=sentence_order)
return data, label
def main():
# initialize
colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl')
logger = get_dist_logger()
# build dataloader
initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
VOCAB_SIZE = get_padded_vocab_size()
trainloader, validloader, testloader = build_train_valid_test_data_iterators(
train_iters=gpc.config.TRAIN_ITERS,
global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
eval_interval=gpc.config.EVAL_INTERVAL,
eval_iters=gpc.config.EVAL_ITERS,
data_prefix=[gpc.config.DATA_PATH],
data_impl='mmap',
splits_string='949,50,1',
max_seq_length=gpc.config.SEQ_LENGTH,
masked_lm_prob=0.15,
short_seq_prob=0.1,
seed=1234,
skip_warmup=True,
binary_head=False,
)
logger.info("Dataloaders are built", ranks=[0])
# build model
if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE:
is_naive_fp16 = True
else:
is_naive_fp16 = False
use_pipeline = is_using_pp()
kwargs = dict(vocab_size=VOCAB_SIZE,
hidden_size=gpc.config.HIDDEN_SIZE,
max_sequence_length=gpc.config.SEQ_LENGTH,
num_attention_heads=gpc.config.NUM_ATTENTION_HEADS,
convert_fp16_to_fp32_in_softmax=True,
is_naive_fp16=is_naive_fp16,
add_binary_head=gpc.config.ADD_BINARY_HEAD)
if use_pipeline:
model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs)
else:
model = BertForPretrain(num_layers=gpc.config.DEPTH, **kwargs)
model = model.half()
model.reset_parameters()
logger.info(f"Model is built with softmax in fp32 = {is_naive_fp16}", ranks=[0])
total_numel = 0
for p in model.parameters():
total_numel += p.numel()
logger.info(f"This model has {total_numel} parameters")
# build criterion
criterion = BertLoss()
logger.info("Criterion is built", ranks=[0])
# layernorm and bias has no weight decay
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in model.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias'])
logger.info(
f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}"
)
# optimizer
optimizer = FusedAdam((weight_decay_params, no_weight_decay_params),
lr=gpc.config.LR,
weight_decay=gpc.config.WEIGHT_DECAY)
logger.info("Optimizer is built", ranks=[0])
# lr scheduler
# follow Megatron-LM setting
warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION)
lr_scheduler = AnnealingLR(optimizer=optimizer,
max_lr=gpc.config.LR,
min_lr=gpc.config.MIN_LR,
warmup_steps=warmup_steps,
decay_steps=gpc.config.DECAY_ITERS,
decay_style='linear')
logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps")
# # init
engine, *dummy = colossalai.initialize(
model,
optimizer,
criterion,
)
# build timer
timer = MultiTimer()
skip_iters = 0
# build loss tracker
accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda()
accumulated_eval_loss = torch.zeros(1, dtype=torch.float32).cuda()
# build data iters for pipeline parallel
if use_pipeline:
train_data_iter = SequenceParallelDataIterator(trainloader)
valid_data_iter = SequenceParallelDataIterator(validloader)
for step in range(1, gpc.config.TRAIN_ITERS + 1):
timer.start('train-iterations')
engine.train()
if use_pipeline:
engine.zero_grad()
_, _, train_loss = engine.execute_schedule(train_data_iter, return_output_label=False)
engine.step()
else:
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
trainloader)
engine.zero_grad()
lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
engine.backward(train_loss)
engine.step()
timer.stop('train-iterations', keep_in_history=True)
if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_train_loss += train_loss
lr_scheduler.step()
if step % gpc.config.EVAL_INTERVAL == 0:
engine.eval()
for j in range(gpc.config.EVAL_ITERS):
with torch.no_grad():
if use_pipeline:
_, _, eval_loss = engine.execute_schedule(valid_data_iter,
forward_only=True,
return_output_label=False)
else:
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel(
validloader)
lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels)
eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order)
if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_eval_loss += eval_loss
if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE):
accumulated_eval_loss /= gpc.config.EVAL_ITERS
accumulated_train_loss /= gpc.config.EVAL_INTERVAL
timer_string = []
for n, t in timer:
timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}")
timer_string = ' | '.join(timer_string)
lr = list(engine.optimizer.param_groups)[0]['lr']
loss_scale = engine.optimizer.optim.loss_scale.item()
if gpc.is_initialized(ParallelMode.PIPELINE):
ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]]
else:
ranks = [0]
logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' +
f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' +
f"| Learning rate: {lr} | " + timer_string,
ranks=ranks)
for n, t in timer:
t.reset()
accumulated_eval_loss.zero_()
accumulated_train_loss.zero_()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,82 @@
Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
CreativeML Open RAIL-M
dated August 22, 2022
Section I: PREAMBLE
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
NOW THEREFORE, You and Licensor agree as follows:
1. Definitions
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
You must cause any modified files to carry prominent notices stating that You changed the files;
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
Section IV: OTHER PROVISIONS
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
8. Trademarks and related. Nothing in this License permits You to make use of Licensors trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
END OF TERMS AND CONDITIONS
Attachment A
Use Restrictions
You agree not to use the Model or Derivatives of the Model:
- In any way that violates any applicable national, federal, state, local or international law or regulation;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate personal identifiable information that can be used to harm an individual;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individuals legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
- To provide medical advice and medical results interpretation;
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).

View File

@ -0,0 +1,115 @@
# Handson 6: Acceleration of Stable Diffusion
*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and
fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*
We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies
, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
## Stable Diffusion
[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
model.
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
<p id="diffusion_train" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/diffusion_train.png" width=800/>
</p>
[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB).
<p id="diffusion_demo" align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/diffusion_demo.png" width=800/>
</p>
## Requirements
A suitable [conda](https://conda.io/) environment named `ldm` can be created
and activated with:
```
conda env create -f environment.yaml
conda activate ldm
```
You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
```
conda install pytorch torchvision -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install -e .
```
### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website
```
pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
```
### Install [Lightning](https://github.com/Lightning-AI/lightning)
We use the Sep. 2022 version with commit id as `b04a7aa`.
```
git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
pip install -r requirements.txt && pip install .
```
> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.
## Dataset
The DataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/),
you should the change the `data.file_path` in the `config/train_colossalai.yaml`
## Training
we provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml`, `train_ddp.yaml`
for example, you can run the training from colossalai by
```
python main.py --logdir /tmp -t --postfix test -b config/train_colossalai.yaml
```
- you can change the `--logdir` the save the log information and the last checkpoint
### Training config
you can change the trainging config in the yaml file
- accelerator: acceleratortype, default 'gpu'
- devices: device number used for training, default 4
- max_epochs: max training epochs
- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai
## Comments
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch),
[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion).
Thanks for open-sourcing!
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch).
## BibTeX
```
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
journal={arXiv preprint arXiv:2110.14883},
year={2021}
}
@misc{rombach2021highresolution,
title={High-Resolution Image Synthesis with Latent Diffusion Models},
author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
year={2021},
eprint={2112.10752},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@article{dao2022flashattention,
title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
journal={arXiv preprint arXiv:2205.14135},
year={2022}
}
```

View File

@ -0,0 +1,116 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1.e-4 ]
f_min: [ 1.e-10 ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
use_fp16: True
data:
target: main.DataModuleFromConfig
params:
batch_size: 64
wrap: False
train:
target: ldm.data.base.Txt2ImgIterableBaseDataset
params:
file_path: "/data/scratch/diffuser/laion_part0/"
world_size: 1
rank: 0
lightning:
trainer:
accelerator: 'gpu'
devices: 4
log_gpu_memory: all
max_epochs: 2
precision: 16
auto_select_gpus: False
strategy:
target: pytorch_lightning.strategies.ColossalAIStrategy
params:
use_chunk: False
enable_distributed_storage: True,
placement_policy: cuda
force_outputs_fp32: False
log_every_n_steps: 2
logger: True
default_root_dir: "/tmp/diff_log/"
profiler: pytorch
logger_config:
wandb:
target: pytorch_lightning.loggers.WandbLogger
params:
name: nowname
save_dir: "/tmp/diff_log/"
offline: opt.debug
id: nowname

View File

@ -0,0 +1,113 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 32
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 100 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1.e-4 ]
f_min: [ 1.e-10 ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
use_fp16: True
data:
target: main.DataModuleFromConfig
params:
batch_size: 64
wrap: False
train:
target: ldm.data.base.Txt2ImgIterableBaseDataset
params:
file_path: "/data/scratch/diffuser/laion_part0/"
world_size: 1
rank: 0
lightning:
trainer:
accelerator: 'gpu'
devices: 4
log_gpu_memory: all
max_epochs: 2
precision: 16
auto_select_gpus: False
strategy:
target: pytorch_lightning.strategies.DDPStrategy
params:
find_unused_parameters: False
log_every_n_steps: 2
# max_steps: 6o
logger: True
default_root_dir: "/tmp/diff_log/"
# profiler: pytorch
logger_config:
wandb:
target: pytorch_lightning.loggers.WandbLogger
params:
name: nowname
save_dir: "/tmp/diff_log/"
offline: opt.debug
id: nowname

View File

@ -0,0 +1,121 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 32
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
check_nan_inf: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1.e-4 ]
f_min: [ 1.e-10 ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: False
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
use_fp16: True
data:
target: main.DataModuleFromConfig
params:
batch_size: 32
wrap: False
train:
target: ldm.data.pokemon.PokemonDataset
# params:
# file_path: "/data/scratch/diffuser/laion_part0/"
# world_size: 1
# rank: 0
lightning:
trainer:
accelerator: 'gpu'
devices: 4
log_gpu_memory: all
max_epochs: 2
precision: 16
auto_select_gpus: False
strategy:
target: pytorch_lightning.strategies.ColossalAIStrategy
params:
use_chunk: False
enable_distributed_storage: True,
placement_policy: cuda
force_outputs_fp32: False
initial_scale: 65536
min_scale: 1
max_scale: 65536
# max_scale: 4294967296
log_every_n_steps: 2
logger: True
default_root_dir: "/tmp/diff_log/"
profiler: pytorch
logger_config:
wandb:
target: pytorch_lightning.loggers.WandbLogger
params:
name: nowname
save_dir: "/tmp/diff_log/"
offline: opt.debug
id: nowname

View File

@ -0,0 +1,32 @@
name: ldm
channels:
- pytorch
- defaults
dependencies:
- python=3.9.12
- pip=20.3
- cudatoolkit=11.3
- pytorch=1.11.0
- torchvision=0.12.0
- numpy=1.19.2
- pip:
- albumentations==0.4.3
- diffusers
- opencv-python==4.6.0.66
- pudb==2019.2
- invisible-watermark
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.4.2
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit>=0.73.1
- einops==0.3.0
- torch-fidelity==0.3.0
- transformers==4.19.2
- torchmetrics==0.6.0
- kornia==0.6
- prefetch_generator
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e .

View File

@ -0,0 +1,75 @@
import math
from abc import abstractmethod
import torch
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
import os
import numpy as np
import cv2
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, file_path: str, rank, world_size):
super().__init__()
self.file_path = file_path
self.folder_list = []
self.file_list = []
self.txt_list = []
self.info = self._get_file_info(file_path)
self.start = self.info['start']
self.end = self.info['end']
self.rank = rank
self.world_size = world_size
# self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
# self.iter_start = self.start + self.rank * self.per_worker
# self.iter_end = min(self.iter_start + self.per_worker, self.end)
# self.num_records = self.iter_end - self.iter_start
# self.valid_ids = [i for i in range(self.iter_end)]
self.num_records = self.end - self.start
self.valid_ids = [i for i in range(self.end)]
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
def __len__(self):
# return self.iter_end - self.iter_start
return self.end - self.start
def __iter__(self):
sample_iterator = self._sample_generator(self.start, self.end)
# sample_iterator = self._sample_generator(self.iter_start, self.iter_end)
return sample_iterator
def _sample_generator(self, start, end):
for idx in range(start, end):
file_name = self.file_list[idx]
txt_name = self.txt_list[idx]
f_ = open(txt_name, 'r')
txt_ = f_.read()
f_.close()
image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = torch.from_numpy(image) / 255
yield {"caption": txt_, "image":image}
def _get_file_info(self, file_path):
info = \
{
"start": 1,
"end": 0,
}
self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
for folder in self.folder_list:
files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
txts = [k.replace('jpg', 'txt') for k in files]
self.file_list.extend(files)
self.txt_list.extend(txts)
info['end'] = len(self.file_list)
# with open(file_path, 'r') as fin:
# for _ in enumerate(fin):
# info['end'] += 1
# self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list]
return info

View File

@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
self._prepare_idx_to_synset()
self._prepare_human_to_integer_label()
self._load()
def __len__(self):
return len(self.data)
def __getitem__(self, i):
return self.data[i]
def _prepare(self):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
if syn in synsets:
files.append(rpath)
return files
else:
return relpaths
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(":")
self.human2integer_dict[key] = int(value)
def _load(self):
with open(self.txt_filelist, "r") as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
}
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
else:
self.data = self.abspaths
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
FILES = [
"ILSVRC2012_img_train.tar",
]
SIZES = [
147897477120,
]
def __init__(self, process_images=True, data_root=None, **kwargs):
self.process_images = process_images
self.data_root = data_root
super().__init__(**kwargs)
def _prepare(self):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
]
SIZES = [
6744924160,
1950000,
]
def __init__(self, process_images=True, data_root=None, **kwargs):
self.data_root = data_root
self.process_images = process_images
super().__init__(**kwargs)
def _prepare(self):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
for k, v in synset_dict.items():
src = os.path.join(datadir, k)
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
"""
Imagenet Superresolution Dataloader
Performs following ops in order:
1. crops a crop of size s from image either as random or center crop
2. resizes crop to size with cv2.area_interpolation
3. degrades resized crop with degradation_fn
:param size: resizing to size after cropping
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
:param downscale_f: Low Resolution Downsample factor
:param min_crop_f: determines crop size s,
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
:param max_crop_f: ""
:param data_root:
:param random_crop:
"""
self.base = self.get_base()
assert size
assert (size / downscale_f).is_integer()
self.size = size
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
LR_image = self.degradation_process(image_pil)
LR_image = np.array(LR_image).astype(np.uint8)
else:
LR_image = self.degradation_process(image=image)["image"]
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
return example
class ImageNetSRTrain(ImageNetSR):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
return Subset(dset, indices)
class ImageNetSRValidation(ImageNetSR):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
return Subset(dset, indices)

View File

@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)

View File

@ -0,0 +1,98 @@
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
self.last_f = f
return f

View File

@ -0,0 +1,544 @@
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
from_pretrained: str=None
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
from diffusers.modeling_utils import load_state_dict
if from_pretrained is not None:
state_dict = load_state_dict(from_pretrained)
self._load_pretrained_model(state_dict)
def _state_key_mapping(self, state_dict: dict):
import re
res_dict = {}
key_list = state_dict.keys()
key_str = " ".join(key_list)
up_block_pattern = re.compile('upsamplers')
p1 = re.compile('mid.block_[0-9]')
p2 = re.compile('decoder.up.[0-9]')
up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1)
for key_, val_ in state_dict.items():
key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
.replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
.replace('mid.attentions.0.key', 'mid.attn_1.k')\
.replace('mid.attentions.0.query', 'mid.attn_1.q') \
.replace('mid.attentions.0.value', 'mid.attn_1.v') \
.replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
.replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
.replace('upsamplers.0', 'upsample')\
.replace('downsamplers.0', 'downsample')\
.replace('conv_shortcut', 'nin_shortcut')\
.replace('conv_norm_out', 'norm_out')
mid_list = re.findall(p1, key_)
if len(mid_list) != 0:
mid_str = mid_list[0]
mid_id = int(mid_str[-1]) + 1
key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))
up_list = re.findall(p2, key_)
if len(up_list) != 0:
up_str = up_list[0]
up_id = up_blocks_count - 1 -int(up_str[-1])
key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
res_dict[key_] = val_
return res_dict
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
state_dict = self._state_key_mapping(state_dict)
model_state_dict = self.state_dict()
loaded_keys = [k for k in state_dict.keys()]
expected_keys = list(model_state_dict.keys())
original_loaded_keys = loaded_keys
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
def _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
ignore_mismatched_sizes,
):
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
return mismatched_keys
if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
original_loaded_keys,
ignore_mismatched_sizes,
)
error_msgs = self._load_state_dict_into_model(state_dict)
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
def _load_state_dict_into_model(self, state_dict):
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix=""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(self)
return error_msgs
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x

View File

@ -0,0 +1,267 @@
import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(self,
diffusion_path,
num_classes,
ckpt_path=None,
pool='attention',
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
else self.diffusion_model.cond_stage_key
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__:
raise NotImplementedError()
self.load_classifier(ckpt_path, pool)
self.scheduler_config = scheduler_config
self.use_scheduler = self.scheduler_config is not None
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
model_config.out_channels = self.num_classes
if self.label_key == 'class_label':
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print('#####################################################################')
print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################')
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
def get_x_noisy(self, x, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@torch.no_grad()
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
@torch.no_grad()
def get_conditioning(self, batch, k=None):
if k is None:
k = self.label_key
assert k is not None, 'Needs to provide label key'
targets = batch[k].to(self.device)
if self.label_key == 'segmentation':
targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
# save some memory
self.diffusion_model.model.to('cpu')
@torch.no_grad()
def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val'
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
)
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
loss = F.cross_entropy(logits, targets, reduction='none')
self.write_logs(loss.detach(), logits.detach(), targets.detach())
loss = loss.mean()
return loss, logits, x_noisy, targets
def training_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
return loss
def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
def on_validation_start(self):
self.reset_noise_accs()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
return loss
def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [optimizer], scheduler
return optimizer
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
log = dict()
x = self.get_input(batch, self.diffusion_model.first_stage_key)
log['inputs'] = x
y = self.get_conditioning(batch)
if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log['labels'] = y
if ismap(y):
log['labels'] = self.diffusion_model.to_rgb(y)
for step in range(self.log_steps):
current_time = step * self.log_time_interval
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
for key in log:
log[key] = log[key][:N]
return log

View File

@ -0,0 +1,240 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
return x_dec

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,236 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t

View File

@ -0,0 +1,314 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from torch.utils import checkpoint
try:
from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
FlASH_AVAILABLE = True
except:
FlASH_AVAILABLE = False
USE_FLASH = False
def enable_flash_attention():
global USE_FLASH
USE_FLASH = True
if FlASH_AVAILABLE is False:
print("Please install flash attention to activate new attention kernel.\n" +
"Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
dim_head = q.shape[-1] / self.heads
if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
dim_head <= 128 and (dim_head % 8) == 0:
# print("in flash")
if q.shape[1] == k.shape[1]:
out = self._flash_attention_qkv(q, k, v)
else:
out = self._flash_attention_q_kv(q, k, v)
else:
out = self._native_attention(q, k, v, self.heads, mask)
return self.to_out(out)
def _native_attention(self, q, k, v, h, mask):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
out = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', out, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return out
def _flash_attention_qkv(self, q, k, v):
qkv = torch.stack([q, k, v], dim=2)
b = qkv.shape[0]
n = qkv.shape[1]
qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
out = flash_attention_qkv(qkv, self.scale, b, n)
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
return out
def _flash_attention_q_kv(self, q, k, v):
kv = torch.stack([k, v], dim=2)
b = q.shape[0]
q_seqlen = q.shape[1]
kv_seqlen = kv.shape[1]
q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
return out
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint
def forward(self, x, context=None):
if self.use_checkpoint:
return checkpoint(self._forward, x, context)
else:
return self._forward(x, context)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, use_checkpoint=False):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
x = x.contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = x.contiguous()
x = self.proj_out(x)
return x + x_in

View File

@ -0,0 +1,862 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class temb_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
# self.temb = nn.Module()
self.temb = temb_module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# down = nn.Module()
down = Down_module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
# self.mid = nn.Module()
self.mid = Mid_module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
skip_in = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
if i_block == self.num_res_blocks:
skip_in = ch*in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# up = nn.Module()
up = Up_module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x, t=None, context=None):
#assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Down_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Up_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Mid_module(nn.Module):
def __init__(self):
super().__init__()
pass
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# down = nn.Module()
down = Down_module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
# self.mid = nn.Module()
self.mid = Mid_module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
# self.mid = nn.Module()
self.mid = Mid_module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
# up = nn.Module()
up = Up_module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
nn.Conv2d(2*in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1,2,3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
ch_mult=(2,2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1)
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.conv_out = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
class MergedRescaleEncoder(nn.Module):
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True,
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
z_channels=intermediate_chn, double_z=False, resolution=resolution,
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
tmp_chn = z_channels*ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
ch_mult=ch_mult, resolution=resolution, ch=ch)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
out_channels=tmp_chn, depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size//in_size))+1
factor_up = 1.+ (out_size % in_size)
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
attn_resolutions=[], in_channels=None, ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=4,
stride=2,
padding=1)
def forward(self, x, scale_factor=1.0):
if scale_factor==1.0:
return x
else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
return x
class FirstStagePostProcessor(nn.Module):
def __init__(self, ch_mult:list, in_channels,
pretrained_model:nn.Module=None,
reshape=False,
n_channels=None,
dropout=0.,
pretrained_config=None):
super().__init__()
if pretrained_config is None:
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
stride=1,padding=1)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
# self.pretrained_model.train = False
for param in self.pretrained_model.parameters():
param.requires_grad = False
@torch.no_grad()
def encode_with_pretrained(self,x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
return c
def forward(self,x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)
z = downmodel(z)
if self.do_reshape:
z = rearrange(z,'b c h w -> b (h w) c')
return z

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,276 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
if use_fp16:
return embedding.half()
else:
return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels, precision=16):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
if precision == 16:
return GroupNorm16(16, channels)
else:
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm16(nn.GroupNorm):
def forward(self, x):
return super().forward(x.half()).type(x.dtype)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

View File

@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

View File

@ -0,0 +1,76 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
else torch.tensor(-1,dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
#remove as '.'-character is not allowed in buffers
s_name = name.replace('.','')
self.m_name2s_name.update({name:s_name})
self.register_buffer(s_name,p.clone().detach().data)
self.collected_params = []
def forward(self,model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)

View File

@ -0,0 +1,264 @@
import types
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
import kornia
from transformers.models.clip.modeling_clip import CLIPTextTransformer
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, text):
# output of length 77
return self(text)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
def forward(self,x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
class CLIPTextModelZero(CLIPTextModel):
config_class = CLIPTextConfig
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformerZero(config)
class CLIPTextTransformerZero(CLIPTextTransformer):
def _build_causal_attention_mask(self, bsz, seq_len):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask.half()
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
if use_fp16:
self.transformer = CLIPTextModelZero.from_pretrained(version)
else:
self.transformer = CLIPTextModel.from_pretrained(version)
# print(self.transformer.modules())
# print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype))
self.device = device
self.max_length = max_length
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
# tokens = batch_encoding["input_ids"].to(self.device)
tokens = batch_encoding["input_ids"].to(self.device)
# print("token type: {}".format(tokens.dtype))
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
self.normalize = normalize
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = clip.tokenize(text).to(self.device)
z = self.model.encode_text(tokens)
if self.normalize:
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
return z
def encode(self, text):
z = self(text)
if z.ndim==2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__":
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)

View File

@ -0,0 +1,50 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
"""
import torch
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
except ImportError:
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
"""
Arguments:
qkv: (batch*seq, 3, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Return:
out: (total, nheads, headdim).
"""
max_s = seq_len
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
device=qkv.device)
out = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, 0.0,
softmax_scale=sm_scale, causal=False
)
return out
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
"""
Arguments:
q: (batch*seq, nheads, headdim)
kv: (batch*seq, 2, nheads, headdim)
batch_size: int.
seq_len: int.
sm_scale: float. The scaling of QK^T before applying softmax.
Return:
out: (total, nheads, headdim).
"""
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
return out

View File

@ -0,0 +1,2 @@
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light

View File

@ -0,0 +1,730 @@
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import numpy as np
import cv2
import torch
from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations
import ldm.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
'''
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
return k
def gm_blur_kernel(mean, cov, size=15):
center = size / 2.0 + 0.5
k = np.zeros([size, size])
for y in range(size):
for x in range(size):
cy = y - center + 1
cx = x - center + 1
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
k = k / np.sum(k)
return k
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf - 1) * 0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w - 1)
y1 = np.clip(y1, 0, h - 1)
if x.ndim == 2:
x = interp2d(xv, yv, x)(x1, y1)
if x.ndim == 3:
for i in range(x.shape[-1]):
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
def blur(x, k):
'''
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
theta = np.random.rand() * np.pi # random theta
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
Z = np.stack([X, Y], 2)[:, :, :, None]
# Calcualte Gaussian for every pixel of the kernel
ZZ = Z - MU
ZZ_t = ZZ.transpose(0, 1, 3, 2)
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel = raw_kernel / np.sum(raw_kernel)
return kernel
def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
h = h / sumh
return h
def fspecial_laplacian(alpha):
alpha = max([0, min([alpha, 1])])
h1 = alpha / (alpha + 1)
h2 = (1 - alpha) / (alpha + 1)
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
h = np.array(h)
return h
def fspecial(filter_type, *args, **kwargs):
'''
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
return fspecial_laplacian(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
'''
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
'''
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype('float32')
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
K = img + weight * residual
K = np.clip(K, 0, 1)
return soft_mask * K + (1 - soft_mask) * img
def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
else:
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
return img
def add_resize(img, sf=4):
rnum = np.random.rand()
if rnum > 0.8: # up
sf1 = random.uniform(1, 2)
elif rnum < 0.7: # down
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = np.clip(img, 0.0, 1.0)
return img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
def add_JPEG_noise(img):
quality_factor = random.randint(30, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f'img size ({h1}X{w1}) is too small!')
hq = img.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_blur(img, sf=sf)
elif i == 2:
a, b = img.shape[1], img.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = np.clip(img, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
img = add_JPEG_noise(img)
elif i == 6:
# add processed camera sensor noise
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
for i in shuffle_order:
if i == 0:
image = add_blur(image, sf=sf)
elif i == 1:
image = add_blur(image, sf=sf)
elif i == 2:
a, b = image.shape[1], image.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = np.clip(image, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
image = add_JPEG_noise(image)
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image":image}
return example
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
"""
This is an extended degradation model by combining
the degradation models of BSRGAN and Real-ESRGAN
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
use_shuffle: the degradation shuffle
use_sharp: sharpening the img
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f'img size ({h1}X{w1}) is too small!')
if use_sharp:
img = add_sharpening(img)
hq = img.copy()
if random.random() < shuffle_prob:
shuffle_order = random.sample(range(13), 13)
else:
shuffle_order = list(range(13))
# local shuffle for noise, JPEG is always the last one
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_resize(img, sf=sf)
elif i == 2:
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 3:
if random.random() < poisson_prob:
img = add_Poisson_noise(img)
elif i == 4:
if random.random() < speckle_prob:
img = add_speckle_noise(img)
elif i == 5:
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
elif i == 6:
img = add_JPEG_noise(img)
elif i == 7:
img = add_blur(img, sf=sf)
elif i == 8:
img = add_resize(img, sf=sf)
elif i == 9:
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
elif i == 10:
if random.random() < poisson_prob:
img = add_Poisson_noise(img)
elif i == 11:
if random.random() < speckle_prob:
img = add_speckle_noise(img)
elif i == 12:
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
else:
print('check the shuffle!')
# resize to desired size
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
interpolation=random.choice([1, 2, 3]))
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf, lq_patchsize)
return img, hq
if __name__ == '__main__':
print("hey")
img = util.imread_uint('utils/test.png', 3)
print(img)
img = util.uint2single(img)
print(img)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_lq = deg_fn(img)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + '.png')

View File

@ -0,0 +1,650 @@
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import torch
from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations
import ldm.modules.image_degradation.utils_image as util
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
def modcrop_np(img, sf):
'''
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
'''
w, h = img.shape[:2]
im = np.copy(img)
return im[:w - w % sf, :h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
return cropped_big_k / cropped_big_k.sum()
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
""" generate an anisotropic Gaussian kernel
Args:
ksize : e.g., 15, kernel size
theta : [0, pi], rotation angle range
l1 : [0.1,50], scaling of eigenvalues
l2 : [0.1,l1], scaling of eigenvalues
If l1 = l2, will get an isotropic Gaussian kernel.
Returns:
k : kernel
"""
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
D = np.array([[l1, 0], [0, l2]])
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
return k
def gm_blur_kernel(mean, cov, size=15):
center = size / 2.0 + 0.5
k = np.zeros([size, size])
for y in range(size):
for x in range(size):
cy = y - center + 1
cx = x - center + 1
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
k = k / np.sum(k)
return k
def shift_pixel(x, sf, upper_left=True):
"""shift pixel for super-resolution with different scale factors
Args:
x: WxHxC or WxH
sf: scale factor
upper_left: shift direction
"""
h, w = x.shape[:2]
shift = (sf - 1) * 0.5
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
if upper_left:
x1 = xv + shift
y1 = yv + shift
else:
x1 = xv - shift
y1 = yv - shift
x1 = np.clip(x1, 0, w - 1)
y1 = np.clip(y1, 0, h - 1)
if x.ndim == 2:
x = interp2d(xv, yv, x)(x1, y1)
if x.ndim == 3:
for i in range(x.shape[-1]):
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
return x
def blur(x, k):
'''
x: image, NxcxHxW
k: kernel, Nx1xhxw
'''
n, c = x.shape[:2]
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
k = k.repeat(1, c, 1, 1)
k = k.view(-1, 1, k.shape[2], k.shape[3])
x = x.view(1, -1, x.shape[2], x.shape[3])
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
x = x.view(n, c, x.shape[2], x.shape[3])
return x
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
""""
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
# Kai Zhang
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
# max_var = 2.5 * sf
"""
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
theta = np.random.rand() * np.pi # random theta
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
# Set COV matrix using Lambdas and Theta
LAMBDA = np.diag([lambda_1, lambda_2])
Q = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
SIGMA = Q @ LAMBDA @ Q.T
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
# Set expectation position (shifting kernel for aligned image)
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
MU = MU[None, None, :, None]
# Create meshgrid for Gaussian
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
Z = np.stack([X, Y], 2)[:, :, :, None]
# Calcualte Gaussian for every pixel of the kernel
ZZ = Z - MU
ZZ_t = ZZ.transpose(0, 1, 3, 2)
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
# shift the kernel so it will be centered
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
# Normalize the kernel and return
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
kernel = raw_kernel / np.sum(raw_kernel)
return kernel
def fspecial_gaussian(hsize, sigma):
hsize = [hsize, hsize]
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
std = sigma
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
arg = -(x * x + y * y) / (2 * std * std)
h = np.exp(arg)
h[h < scipy.finfo(float).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
h = h / sumh
return h
def fspecial_laplacian(alpha):
alpha = max([0, min([alpha, 1])])
h1 = alpha / (alpha + 1)
h2 = (1 - alpha) / (alpha + 1)
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
h = np.array(h)
return h
def fspecial(filter_type, *args, **kwargs):
'''
python code from:
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
'''
if filter_type == 'gaussian':
return fspecial_gaussian(*args, **kwargs)
if filter_type == 'laplacian':
return fspecial_laplacian(*args, **kwargs)
"""
# --------------------------------------------
# degradation models
# --------------------------------------------
"""
def bicubic_degradation(x, sf=3):
'''
Args:
x: HxWxC image, [0, 1]
sf: down-scale factor
Return:
bicubicly downsampled LR image
'''
x = util.imresize_np(x, scale=1 / sf)
return x
def srmd_degradation(x, k, sf=3):
''' blur + bicubic downsampling
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2018learning,
title={Learning a single convolutional super-resolution network for multiple degradations},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={3262--3271},
year={2018}
}
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
x = bicubic_degradation(x, sf=sf)
return x
def dpsr_degradation(x, k, sf=3):
''' bicubic downsampling + blur
Args:
x: HxWxC image, [0, 1]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
Reference:
@inproceedings{zhang2019deep,
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
pages={1671--1681},
year={2019}
}
'''
x = bicubic_degradation(x, sf=sf)
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
return x
def classical_degradation(x, k, sf=3):
''' blur + downsampling
Args:
x: HxWxC image, [0, 1]/[0, 255]
k: hxw, double
sf: down-scale factor
Return:
downsampled LR image
'''
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
st = 0
return x[st::sf, st::sf, ...]
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
"""USM sharpening. borrowed from real-ESRGAN
Input image: I; Blurry image: B.
1. K = I + weight * (I - B)
2. Mask = 1 if abs(I - B) > threshold, else: 0
3. Blur mask:
4. Out = Mask * K + (1 - Mask) * I
Args:
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
weight (float): Sharp weight. Default: 1.
radius (float): Kernel size of Gaussian blur. Default: 50.
threshold (int):
"""
if radius % 2 == 0:
radius += 1
blur = cv2.GaussianBlur(img, (radius, radius), 0)
residual = img - blur
mask = np.abs(residual) * 255 > threshold
mask = mask.astype('float32')
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
K = img + weight * residual
K = np.clip(K, 0, 1)
return soft_mask * K + (1 - soft_mask) * img
def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
wd2 = wd2/4
wd = wd/4
if random.random() < 0.5:
l1 = wd2 * random.random()
l2 = wd2 * random.random()
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
else:
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
return img
def add_resize(img, sf=4):
rnum = np.random.rand()
if rnum > 0.8: # up
sf1 = random.uniform(1, 2)
elif rnum < 0.7: # down
sf1 = random.uniform(0.5 / sf, 1)
else:
sf1 = 1.0
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
img = np.clip(img, 0.0, 1.0)
return img
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
# noise_level = random.randint(noise_level1, noise_level2)
# rnum = np.random.rand()
# if rnum > 0.6: # add color Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
# elif rnum < 0.4: # add grayscale Gaussian noise
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
# else: # add noise
# L = noise_level2 / 255.
# D = np.diag(np.random.rand(3))
# U = orth(np.random.rand(3, 3))
# conv = np.dot(np.dot(np.transpose(U), D), U)
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
# img = np.clip(img, 0.0, 1.0)
# return img
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
rnum = np.random.rand()
if rnum > 0.6: # add color Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4: # add grayscale Gaussian noise
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else: # add noise
L = noise_level2 / 255.
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
noise_level = random.randint(noise_level1, noise_level2)
img = np.clip(img, 0.0, 1.0)
rnum = random.random()
if rnum > 0.6:
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
elif rnum < 0.4:
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
else:
L = noise_level2 / 255.
D = np.diag(np.random.rand(3))
U = orth(np.random.rand(3, 3))
conv = np.dot(np.dot(np.transpose(U), D), U)
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
img = np.clip(img, 0.0, 1.0)
return img
def add_Poisson_noise(img):
img = np.clip((img * 255.0).round(), 0, 255) / 255.
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
if random.random() < 0.5:
img = np.random.poisson(img * vals).astype(np.float32) / vals
else:
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
img += noise_gray[:, :, np.newaxis]
img = np.clip(img, 0.0, 1.0)
return img
def add_JPEG_noise(img):
quality_factor = random.randint(80, 95)
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
img = cv2.imdecode(encimg, 1)
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
return img
def random_crop(lq, hq, sf=4, lq_patchsize=64):
h, w = lq.shape[:2]
rnd_h = random.randint(0, h - lq_patchsize)
rnd_w = random.randint(0, w - lq_patchsize)
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
return lq, hq
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = img.shape[:2]
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
h, w = img.shape[:2]
if h < lq_patchsize * sf or w < lq_patchsize * sf:
raise ValueError(f'img size ({h1}X{w1}) is too small!')
hq = img.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
img = util.imresize_np(img, 1 / 2, True)
img = np.clip(img, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
for i in shuffle_order:
if i == 0:
img = add_blur(img, sf=sf)
elif i == 1:
img = add_blur(img, sf=sf)
elif i == 2:
a, b = img.shape[1], img.shape[0]
# downsample2
if random.random() < 0.75:
sf1 = random.uniform(1, 2 * sf)
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
img = img[0::sf, 0::sf, ...] # nearest downsampling
img = np.clip(img, 0.0, 1.0)
elif i == 3:
# downsample3
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
img = np.clip(img, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
img = add_JPEG_noise(img)
elif i == 6:
# add processed camera sensor noise
if random.random() < isp_prob and isp_model is not None:
with torch.no_grad():
img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
img = add_JPEG_noise(img)
# random crop
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
return img, hq
# todo no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
"""
This is the degradation model of BSRGAN from the paper
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
----------
sf: scale factor
isp_model: camera ISP model
Returns
-------
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
"""
image = util.uint2single(image)
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
sf_ori = sf
h1, w1 = image.shape[:2]
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
h, w = image.shape[:2]
hq = image.copy()
if sf == 4 and random.random() < scale2_prob: # downsample1
if np.random.rand() < 0.5:
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
image = util.imresize_np(image, 1 / 2, True)
image = np.clip(image, 0.0, 1.0)
sf = 2
shuffle_order = random.sample(range(7), 7)
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
if idx1 > idx2: # keep downsample3 last
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
for i in shuffle_order:
if i == 0:
image = add_blur(image, sf=sf)
# elif i == 1:
# image = add_blur(image, sf=sf)
if i == 0:
pass
elif i == 2:
a, b = image.shape[1], image.shape[0]
# downsample2
if random.random() < 0.8:
sf1 = random.uniform(1, 2 * sf)
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
interpolation=random.choice([1, 2, 3]))
else:
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
k_shifted = shift_pixel(k, sf)
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
image = image[0::sf, 0::sf, ...] # nearest downsampling
image = np.clip(image, 0.0, 1.0)
elif i == 3:
# downsample3
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
image = np.clip(image, 0.0, 1.0)
elif i == 4:
# add Gaussian noise
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
elif i == 5:
# add JPEG noise
if random.random() < jpeg_prob:
image = add_JPEG_noise(image)
#
# elif i == 6:
# # add processed camera sensor noise
# if random.random() < isp_prob and isp_model is not None:
# with torch.no_grad():
# img, hq = isp_model.forward(img.copy(), hq)
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
example = {"image": image}
return example
if __name__ == '__main__':
print("hey")
img = util.imread_uint('utils/test.png', 3)
img = img[:448, :448]
h = img.shape[0] // 4
print("resizing to", h)
sf = 4
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
for i in range(20):
print(i)
img_hq = img
img_lq = deg_fn(img)["image"]
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
interpolation=0)
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
util.imsave(img_concat, str(i) + '.png')

Some files were not shown because too many files have changed in this diff Show More