mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] removed duplicated tutorials (#1904)
parent
351f0f64e6
commit
cb7ec714c8
|
@ -2,7 +2,7 @@
|
|||
|
||||
## Prepare Dataset
|
||||
|
||||
We use CIFAR10 dataset in this example. The dataset will be downloaded to `./data` by default.
|
||||
We use CIFAR10 dataset in this example. The dataset will be downloaded to `./data` by default.
|
||||
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
|
||||
|
||||
```bash
|
||||
|
@ -14,4 +14,4 @@ export DATA=/path/to/data
|
|||
|
||||
```bash
|
||||
colossalai run --nproc_per_node 4 auto_parallel_demo.py
|
||||
```
|
||||
```
|
|
@ -1,26 +1,28 @@
|
|||
from pathlib import Path
|
||||
from colossalai.logging import get_dist_logger
|
||||
import colossalai
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from titans.utils import barrier_context
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_dataloader
|
||||
from torchvision import transforms
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingLR
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet50
|
||||
from tqdm import tqdm
|
||||
from titans.utils import barrier_context
|
||||
|
||||
import colossalai
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
||||
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
||||
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
||||
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingLR
|
||||
from colossalai.utils import get_dataloader
|
||||
|
||||
DATA_ROOT = Path(os.environ.get('DATA', './data'))
|
||||
BATCH_SIZE = 1024
|
|
@ -1,27 +0,0 @@
|
|||
# Handson 1: Multi-dimensional Parallelism with Colossal-AI
|
||||
|
||||
|
||||
## Install Colossal-AI and other dependencies
|
||||
|
||||
```bash
|
||||
sh install.sh
|
||||
```
|
||||
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
|
||||
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
|
||||
|
||||
```bash
|
||||
export DATA=/path/to/data
|
||||
```
|
||||
|
||||
|
||||
## Run on 2*2 device mesh
|
||||
|
||||
Current configuration setting on `config.py` is TP=2, PP=2.
|
||||
|
||||
```bash
|
||||
colossalai run --nproc_per_node 4 train.py --config config.py
|
||||
```
|
|
@ -1,36 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
# hyperparameters
|
||||
# BATCH_SIZE is as per GPU
|
||||
# global batch size = BATCH_SIZE x data parallel size
|
||||
BATCH_SIZE = 256
|
||||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
NUM_EPOCHS = 10
|
||||
WARMUP_EPOCHS = 3
|
||||
|
||||
# model config
|
||||
IMG_SIZE = 224
|
||||
PATCH_SIZE = 16
|
||||
HIDDEN_SIZE = 512
|
||||
DEPTH = 4
|
||||
NUM_HEADS = 4
|
||||
MLP_RATIO = 2
|
||||
NUM_CLASSES = 1000
|
||||
CHECKPOINT = False
|
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
|
||||
|
||||
# parallel setting
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
parallel = dict(
|
||||
pipeline=2,
|
||||
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
# pipeline config
|
||||
NUM_MICRO_BATCHES = parallel['pipeline']
|
|
@ -1,4 +0,0 @@
|
|||
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
pip install colossalai==0.1.10+torch1.12cu11.3 -f https://release.colossalai.org
|
||||
pip install titans
|
||||
colossalai check -i
|
|
@ -1,116 +0,0 @@
|
|||
import os
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
use_pipeline = is_using_pp()
|
||||
|
||||
# create model
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
||||
patch_size=gpc.config.PATCH_SIZE,
|
||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
||||
depth=gpc.config.DEPTH,
|
||||
num_heads=gpc.config.NUM_HEADS,
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
num_classes=10,
|
||||
init_method='jax',
|
||||
checkpoint=gpc.config.CHECKPOINT)
|
||||
|
||||
if use_pipeline:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(
|
||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
else:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
|
||||
# count number of parameters
|
||||
total_numel = 0
|
||||
for p in model.parameters():
|
||||
total_numel += p.numel()
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_stage = 0
|
||||
else:
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(
|
||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data/cifar10')
|
||||
train_dataloader, test_dataloader = build_cifar(
|
||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(
|
||||
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
# training
|
||||
engine.train()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
for _ in progress:
|
||||
engine.zero_grad()
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,20 +0,0 @@
|
|||
# Handson 2: Sequence Parallelism with BERT
|
||||
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
|
||||
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
|
||||
|
||||
```bash
|
||||
export DATA=/path/to/data
|
||||
```
|
||||
|
||||
|
||||
## Run on 2*2 device mesh
|
||||
|
||||
Current configuration setting on `config.py` is TP=2, PP=2.
|
||||
|
||||
```bash
|
||||
colossalai run --nproc_per_node 4 train.py --config config.py
|
||||
```
|
|
@ -1,35 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
# hyperparameters
|
||||
# BATCH_SIZE is as per GPU
|
||||
# global batch size = BATCH_SIZE x data parallel size
|
||||
BATCH_SIZE = 256
|
||||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
NUM_EPOCHS = 10
|
||||
WARMUP_EPOCHS = 3
|
||||
|
||||
# model config
|
||||
IMG_SIZE = 224
|
||||
PATCH_SIZE = 16
|
||||
HIDDEN_SIZE = 512
|
||||
DEPTH = 4
|
||||
NUM_HEADS = 4
|
||||
MLP_RATIO = 2
|
||||
NUM_CLASSES = 1000
|
||||
CHECKPOINT = False
|
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
|
||||
|
||||
# parallel setting
|
||||
TENSOR_PARALLEL_SIZE = 1
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
parallel = dict(
|
||||
tensor=dict(size=4, mode='sequence')
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
# pipeline config
|
||||
NUM_MICRO_BATCHES = parallel['pipeline']
|
|
@ -1,116 +0,0 @@
|
|||
import os
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
use_pipeline = is_using_pp()
|
||||
|
||||
# create model
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
||||
patch_size=gpc.config.PATCH_SIZE,
|
||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
||||
depth=gpc.config.DEPTH,
|
||||
num_heads=gpc.config.NUM_HEADS,
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
num_classes=10,
|
||||
init_method='jax',
|
||||
checkpoint=gpc.config.CHECKPOINT)
|
||||
|
||||
if use_pipeline:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(
|
||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
else:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
|
||||
# count number of parameters
|
||||
total_numel = 0
|
||||
for p in model.parameters():
|
||||
total_numel += p.numel()
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_stage = 0
|
||||
else:
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(
|
||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data/cifar10')
|
||||
train_dataloader, test_dataloader = build_cifar(
|
||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(
|
||||
), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
# training
|
||||
engine.train()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
for _ in progress:
|
||||
engine.zero_grad()
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,17 +0,0 @@
|
|||
# Handson 4: Comparison of Large Batch Training Optimization
|
||||
|
||||
## Prepare Dataset
|
||||
|
||||
We use CIFAR10 dataset in this example. The dataset will be downloaded to `../data` by default.
|
||||
If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
|
||||
|
||||
```bash
|
||||
export DATA=/path/to/data
|
||||
```
|
||||
|
||||
|
||||
## Run on 2*2 device mesh
|
||||
|
||||
```bash
|
||||
colossalai run --nproc_per_node 4 train.py --config config.py
|
||||
```
|
|
@ -1,36 +0,0 @@
|
|||
from colossalai.amp import AMP_TYPE
|
||||
|
||||
# hyperparameters
|
||||
# BATCH_SIZE is as per GPU
|
||||
# global batch size = BATCH_SIZE x data parallel size
|
||||
BATCH_SIZE = 512
|
||||
LEARNING_RATE = 3e-3
|
||||
WEIGHT_DECAY = 0.3
|
||||
NUM_EPOCHS = 10
|
||||
WARMUP_EPOCHS = 3
|
||||
|
||||
# model config
|
||||
IMG_SIZE = 224
|
||||
PATCH_SIZE = 16
|
||||
HIDDEN_SIZE = 512
|
||||
DEPTH = 4
|
||||
NUM_HEADS = 4
|
||||
MLP_RATIO = 2
|
||||
NUM_CLASSES = 1000
|
||||
CHECKPOINT = False
|
||||
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
|
||||
|
||||
# parallel setting
|
||||
TENSOR_PARALLEL_SIZE = 2
|
||||
TENSOR_PARALLEL_MODE = '1d'
|
||||
|
||||
parallel = dict(
|
||||
pipeline=2,
|
||||
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
|
||||
)
|
||||
|
||||
fp16 = dict(mode=AMP_TYPE.NAIVE)
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
# pipeline config
|
||||
NUM_MICRO_BATCHES = parallel['pipeline']
|
|
@ -1,117 +0,0 @@
|
|||
import os
|
||||
import colossalai
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import Lars, Lamb
|
||||
from colossalai.utils import is_using_pp, get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from titans.model.vit.vit import _create_vit_model
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
|
||||
|
||||
def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from torch
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
logger.info("initialized distributed environment", ranks=[0])
|
||||
|
||||
if hasattr(gpc.config, 'LOG_PATH'):
|
||||
if gpc.get_global_rank() == 0:
|
||||
log_path = gpc.config.LOG_PATH
|
||||
if not os.path.exists(log_path):
|
||||
os.mkdir(log_path)
|
||||
logger.log_to_file(log_path)
|
||||
|
||||
use_pipeline = is_using_pp()
|
||||
|
||||
# create model
|
||||
model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
|
||||
patch_size=gpc.config.PATCH_SIZE,
|
||||
hidden_size=gpc.config.HIDDEN_SIZE,
|
||||
depth=gpc.config.DEPTH,
|
||||
num_heads=gpc.config.NUM_HEADS,
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
num_classes=10,
|
||||
init_method='jax',
|
||||
checkpoint=gpc.config.CHECKPOINT)
|
||||
|
||||
if use_pipeline:
|
||||
pipelinable = PipelinableContext()
|
||||
with pipelinable:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
pipelinable.to_layer_list()
|
||||
pipelinable.policy = "uniform"
|
||||
model = pipelinable.partition(
|
||||
1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
|
||||
else:
|
||||
model = _create_vit_model(**model_kwargs)
|
||||
|
||||
# count number of parameters
|
||||
total_numel = 0
|
||||
for p in model.parameters():
|
||||
total_numel += p.numel()
|
||||
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
pipeline_stage = 0
|
||||
else:
|
||||
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
logger.info(
|
||||
f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
|
||||
|
||||
# create dataloaders
|
||||
root = os.environ.get('DATA', '../data/cifar10')
|
||||
train_dataloader, test_dataloader = build_cifar(
|
||||
gpc.config.BATCH_SIZE, root, pad_if_needed=True)
|
||||
|
||||
# create loss function
|
||||
criterion = CrossEntropyLoss(label_smoothing=0.1)
|
||||
|
||||
# create optimizer
|
||||
optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE,
|
||||
weight_decay=gpc.config.WEIGHT_DECAY)
|
||||
|
||||
# create lr scheduler
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
|
||||
total_steps=gpc.config.NUM_EPOCHS,
|
||||
warmup_steps=gpc.config.WARMUP_EPOCHS)
|
||||
|
||||
# initialize
|
||||
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
test_dataloader=test_dataloader)
|
||||
|
||||
logger.info("Engine is built", ranks=[0])
|
||||
|
||||
data_iter = iter(train_dataloader)
|
||||
|
||||
for epoch in range(gpc.config.NUM_EPOCHS):
|
||||
# training
|
||||
engine.train()
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS)
|
||||
progress = tqdm(range(len(train_dataloader)), desc=description)
|
||||
else:
|
||||
progress = range(len(train_dataloader))
|
||||
for _ in progress:
|
||||
engine.zero_grad()
|
||||
engine.execute_schedule(data_iter, return_output_label=False)
|
||||
engine.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1 +0,0 @@
|
|||
# Handson 5: Fine-tuning and Serving for OPT from Hugging Face
|
|
@ -1,77 +0,0 @@
|
|||
# Overview
|
||||
|
||||
This is an example showing how to run OPT generation. The OPT model is implemented using ColossalAI.
|
||||
|
||||
It supports tensor parallelism, batching and caching.
|
||||
|
||||
# How to run
|
||||
|
||||
Run OPT-125M:
|
||||
```shell
|
||||
python opt_fastapi.py opt-125m
|
||||
```
|
||||
|
||||
It will launch a HTTP server on `0.0.0.0:7070` by default and you can customize host and port. You can open `localhost:7070/docs` in your browser to see the openapi docs.
|
||||
|
||||
## Configure
|
||||
|
||||
### Configure model
|
||||
```shell
|
||||
python opt_fastapi.py <model>
|
||||
```
|
||||
Available models: opt-125m, opt-6.7b, opt-30b, opt-175b.
|
||||
|
||||
### Configure tensor parallelism
|
||||
```shell
|
||||
python opt_fastapi.py <model> --tp <TensorParallelismWorldSize>
|
||||
```
|
||||
The `<TensorParallelismWorldSize>` can be an integer in `[1, #GPUs]`. Default `1`.
|
||||
|
||||
### Configure checkpoint
|
||||
```shell
|
||||
python opt_fastapi.py <model> --checkpoint <CheckpointPath>
|
||||
```
|
||||
The `<CheckpointPath>` can be a file path or a directory path. If it's a directory path, all files under the directory will be loaded.
|
||||
|
||||
### Configure queue
|
||||
```shell
|
||||
python opt_fastapi.py <model> --queue_size <QueueSize>
|
||||
```
|
||||
The `<QueueSize>` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406).
|
||||
|
||||
### Configure bathcing
|
||||
```shell
|
||||
python opt_fastapi.py <model> --max_batch_size <MaxBatchSize>
|
||||
```
|
||||
The `<MaxBatchSize>` can be an integer in `[1, MAXINT]`. The engine will make batch whose size is less or equal to this value.
|
||||
|
||||
Note that the batch size is not always equal to `<MaxBatchSize>`, as some consecutive requests may not be batched.
|
||||
|
||||
### Configure caching
|
||||
```shell
|
||||
python opt_fastapi.py <model> --cache_size <CacheSize> --cache_list_size <CacheListSize>
|
||||
```
|
||||
This will cache `<CacheSize>` unique requests. And for each unique request, it cache `<CacheListSize>` different results. A random result will be returned if the cache is hit.
|
||||
|
||||
The `<CacheSize>` can be an integer in `[0, MAXINT]`. If it's `0`, cache won't be applied. The `<CacheListSize>` can be an integer in `[1, MAXINT]`.
|
||||
|
||||
### Other configurations
|
||||
```shell
|
||||
python opt_fastapi.py -h
|
||||
```
|
||||
|
||||
# How to benchmark
|
||||
```shell
|
||||
cd benchmark
|
||||
locust
|
||||
```
|
||||
|
||||
Then open the web interface link which is on your console.
|
||||
|
||||
# Pre-process pre-trained weights
|
||||
|
||||
## OPT-66B
|
||||
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).
|
||||
|
||||
## OPT-175B
|
||||
See [script/process-opt-175b](./script/process-opt-175b/).
|
|
@ -1,59 +0,0 @@
|
|||
import torch
|
||||
from typing import List, Deque, Tuple, Hashable, Any
|
||||
from energonai import BatchManager, SubmitEntry, TaskEntry
|
||||
|
||||
|
||||
class BatchManagerForGeneration(BatchManager):
|
||||
def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.max_batch_size = max_batch_size
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
def _left_padding(self, batch_inputs):
|
||||
max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
|
||||
outputs = {'input_ids': [], 'attention_mask': []}
|
||||
for inputs in batch_inputs:
|
||||
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
|
||||
padding_len = max_len - len(input_ids)
|
||||
input_ids = [self.pad_token_id] * padding_len + input_ids
|
||||
attention_mask = [0] * padding_len + attention_mask
|
||||
outputs['input_ids'].append(input_ids)
|
||||
outputs['attention_mask'].append(attention_mask)
|
||||
for k in outputs:
|
||||
outputs[k] = torch.tensor(outputs[k])
|
||||
return outputs, max_len
|
||||
|
||||
@staticmethod
|
||||
def _make_batch_key(entry: SubmitEntry) -> tuple:
|
||||
data = entry.data
|
||||
return (data['top_k'], data['top_p'], data['temperature'])
|
||||
|
||||
def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
|
||||
entry = q.popleft()
|
||||
uids = [entry.uid]
|
||||
batch = [entry.data]
|
||||
while len(batch) < self.max_batch_size:
|
||||
if len(q) == 0:
|
||||
break
|
||||
if self._make_batch_key(entry) != self._make_batch_key(q[0]):
|
||||
break
|
||||
if q[0].data['max_tokens'] > entry.data['max_tokens']:
|
||||
break
|
||||
e = q.popleft()
|
||||
batch.append(e.data)
|
||||
uids.append(e.uid)
|
||||
inputs, max_len = self._left_padding(batch)
|
||||
trunc_lens = []
|
||||
for data in batch:
|
||||
trunc_lens.append(max_len + data['max_tokens'])
|
||||
inputs['top_k'] = entry.data['top_k']
|
||||
inputs['top_p'] = entry.data['top_p']
|
||||
inputs['temperature'] = entry.data['temperature']
|
||||
inputs['max_tokens'] = max_len + entry.data['max_tokens']
|
||||
return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}
|
||||
|
||||
def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
|
||||
retval = []
|
||||
for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens):
|
||||
retval.append((uid, output[:trunc_len]))
|
||||
return retval
|
|
@ -1,15 +0,0 @@
|
|||
from locust import HttpUser, task
|
||||
from json import JSONDecodeError
|
||||
|
||||
|
||||
class GenerationUser(HttpUser):
|
||||
@task
|
||||
def generate(self):
|
||||
prompt = 'Question: What is the longest river on the earth? Answer:'
|
||||
for i in range(4, 9):
|
||||
data = {'max_tokens': 2**i, 'prompt': prompt}
|
||||
with self.client.post('/generation', json=data, catch_response=True) as response:
|
||||
if response.status_code in (200, 406):
|
||||
response.success()
|
||||
else:
|
||||
response.failure('Response wrong')
|
|
@ -1,64 +0,0 @@
|
|||
from collections import OrderedDict
|
||||
from threading import Lock
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Any, Hashable, Dict
|
||||
|
||||
|
||||
class MissCacheError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ListCache:
|
||||
def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None:
|
||||
"""Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied.
|
||||
When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed.
|
||||
|
||||
Args:
|
||||
cache_size (int): Max size for LRU cache.
|
||||
list_size (int): Value list size.
|
||||
fixed_keys (List[Hashable], optional): The keys which won't be removed. Defaults to [].
|
||||
"""
|
||||
self.cache_size = cache_size
|
||||
self.list_size = list_size
|
||||
self.cache: OrderedDict[Hashable, List[Any]] = OrderedDict()
|
||||
self.fixed_cache: Dict[Hashable, List[Any]] = {}
|
||||
for key in fixed_keys:
|
||||
self.fixed_cache[key] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def get(self, key: Hashable) -> List[Any]:
|
||||
with self.lock():
|
||||
if key in self.fixed_cache:
|
||||
l = self.fixed_cache[key]
|
||||
if len(l) >= self.list_size:
|
||||
return l
|
||||
elif key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
l = self.cache[key]
|
||||
if len(l) >= self.list_size:
|
||||
return l
|
||||
raise MissCacheError()
|
||||
|
||||
def add(self, key: Hashable, value: Any) -> None:
|
||||
with self.lock():
|
||||
if key in self.fixed_cache:
|
||||
l = self.fixed_cache[key]
|
||||
if len(l) < self.list_size and value not in l:
|
||||
l.append(value)
|
||||
elif key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
l = self.cache[key]
|
||||
if len(l) < self.list_size and value not in l:
|
||||
l.append(value)
|
||||
else:
|
||||
if len(self.cache) >= self.cache_size:
|
||||
self.cache.popitem(last=False)
|
||||
self.cache[key] = [value]
|
||||
|
||||
@contextmanager
|
||||
def lock(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
yield
|
||||
finally:
|
||||
self._lock.release()
|
|
@ -1,123 +0,0 @@
|
|||
import argparse
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from energonai import QueueFullError, launch_engine
|
||||
from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
async def generate(data: GenerationTaskReq, request: Request):
|
||||
logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
|
||||
key = (data.prompt, data.max_tokens)
|
||||
try:
|
||||
if cache is None:
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(data.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = data.max_tokens
|
||||
inputs['top_k'] = data.top_k
|
||||
inputs['top_p'] = data.top_p
|
||||
inputs['temperature'] = data.temperature
|
||||
try:
|
||||
uid = id(data)
|
||||
engine.submit(uid, inputs)
|
||||
output = await engine.wait(uid)
|
||||
output = tokenizer.decode(output, skip_special_tokens=True)
|
||||
if cache is not None:
|
||||
cache.add(key, output)
|
||||
except QueueFullError as e:
|
||||
raise HTTPException(status_code=406, detail=e.args[0])
|
||||
|
||||
return {'text': output}
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown(*_):
|
||||
engine.shutdown()
|
||||
server.should_exit = True
|
||||
server.force_exit = True
|
||||
await server.shutdown()
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
|
||||
server = uvicorn.Server(config=config)
|
||||
server.run()
|
|
@ -1,122 +0,0 @@
|
|||
import logging
|
||||
import argparse
|
||||
import random
|
||||
from torch import Tensor
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B
|
||||
from transformers import GPT2Tokenizer
|
||||
from energonai import launch_engine, QueueFullError
|
||||
from sanic import Sanic
|
||||
from sanic.request import Request
|
||||
from sanic.response import json
|
||||
from sanic_ext import validate, openapi
|
||||
from batch import BatchManagerForGeneration
|
||||
from cache import ListCache, MissCacheError
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_tokens: int = Field(gt=0, le=256, example=64)
|
||||
prompt: str = Field(
|
||||
min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:')
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
app = Sanic('opt')
|
||||
|
||||
|
||||
@app.post('/generation')
|
||||
@openapi.body(GenerationTaskReq)
|
||||
@validate(json=GenerationTaskReq)
|
||||
async def generate(request: Request, body: GenerationTaskReq):
|
||||
logger.info(f'{request.ip}:{request.port} - "{request.method} {request.path}" - {body}')
|
||||
key = (body.prompt, body.max_tokens)
|
||||
try:
|
||||
if cache is None:
|
||||
raise MissCacheError()
|
||||
outputs = cache.get(key)
|
||||
output = random.choice(outputs)
|
||||
logger.info('Cache hit')
|
||||
except MissCacheError:
|
||||
inputs = tokenizer(body.prompt, truncation=True, max_length=512)
|
||||
inputs['max_tokens'] = body.max_tokens
|
||||
inputs['top_k'] = body.top_k
|
||||
inputs['top_p'] = body.top_p
|
||||
inputs['temperature'] = body.temperature
|
||||
try:
|
||||
uid = id(body)
|
||||
engine.submit(uid, inputs)
|
||||
output = await engine.wait(uid)
|
||||
assert isinstance(output, Tensor)
|
||||
output = tokenizer.decode(output, skip_special_tokens=True)
|
||||
if cache is not None:
|
||||
cache.add(key, output)
|
||||
except QueueFullError as e:
|
||||
return json({'detail': e.args[0]}, status=406)
|
||||
|
||||
return json({'text': output})
|
||||
|
||||
|
||||
@app.after_server_stop
|
||||
def shutdown(*_):
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
def get_model_fn(model_name: str):
|
||||
model_map = {
|
||||
'opt-125m': opt_125M,
|
||||
'opt-6.7b': opt_6B,
|
||||
'opt-30b': opt_30B,
|
||||
'opt-175b': opt_175B
|
||||
}
|
||||
return model_map[model_name]
|
||||
|
||||
|
||||
def print_args(args: argparse.Namespace):
|
||||
print('\n==> Args:')
|
||||
for k, v in args.__dict__.items():
|
||||
print(f'{k} = {v}')
|
||||
|
||||
|
||||
FIXED_CACHE_KEYS = [
|
||||
('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64),
|
||||
('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64),
|
||||
("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64)
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b'])
|
||||
parser.add_argument('--tp', type=int, default=1)
|
||||
parser.add_argument('--master_host', default='localhost')
|
||||
parser.add_argument('--master_port', type=int, default=19990)
|
||||
parser.add_argument('--rpc_port', type=int, default=19980)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--pipe_size', type=int, default=1)
|
||||
parser.add_argument('--queue_size', type=int, default=0)
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
parser.add_argument('--checkpoint', default=None)
|
||||
parser.add_argument('--cache_size', type=int, default=0)
|
||||
parser.add_argument('--cache_list_size', type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
print_args(args)
|
||||
model_kwargs = {}
|
||||
if args.checkpoint is not None:
|
||||
model_kwargs['checkpoint'] = args.checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b')
|
||||
if args.cache_size > 0:
|
||||
cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS)
|
||||
else:
|
||||
cache = None
|
||||
engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model),
|
||||
batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size,
|
||||
pad_token_id=tokenizer.pad_token_id),
|
||||
pipe_size=args.pipe_size,
|
||||
queue_size=args.queue_size,
|
||||
**model_kwargs)
|
||||
app.run(args.http_host, args.http_port)
|
|
@ -1,8 +0,0 @@
|
|||
fastapi==0.85.1
|
||||
locust==2.11.0
|
||||
pydantic==1.10.2
|
||||
sanic==22.9.0
|
||||
sanic_ext==22.9.0
|
||||
torch>=1.10.0
|
||||
transformers==4.23.1
|
||||
uvicorn==0.19.0
|
|
@ -1,46 +0,0 @@
|
|||
# Process OPT-175B weights
|
||||
|
||||
You should download the pre-trained weights following the [doc](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT) before reading this.
|
||||
|
||||
First, install `metaseq` and `git clone https://github.com/facebookresearch/metaseq.git`.
|
||||
|
||||
Then, `cd metaseq`.
|
||||
|
||||
To consolidate checkpoints to eliminate FSDP:
|
||||
|
||||
```shell
|
||||
bash metaseq/scripts/reshard_mp_launch_no_slurm.sh <directory_where_all_the_shards_are>/checkpoint_last <output_dir>/ 8 1
|
||||
```
|
||||
|
||||
You will get 8 files in `<output_dir>`, and you should have the following checksums:
|
||||
```
|
||||
7e71cb65c4be784aa0b2889ac6039ee8 reshard-model_part-0-shard0.pt
|
||||
c8123da04f2c25a9026ea3224d5d5022 reshard-model_part-1-shard0.pt
|
||||
45e5d10896382e5bc4a7064fcafd2b1e reshard-model_part-2-shard0.pt
|
||||
abb7296c4d2fc17420b84ca74fc3ce64 reshard-model_part-3-shard0.pt
|
||||
05dcc7ac6046f4d3f90b3d1068e6da15 reshard-model_part-4-shard0.pt
|
||||
d24dd334019060ce1ee7e625fcf6b4bd reshard-model_part-5-shard0.pt
|
||||
fb1615ce0bbe89cc717f3e5079ee2655 reshard-model_part-6-shard0.pt
|
||||
2f3124432d2dbc6aebfca06be4b791c2 reshard-model_part-7-shard0.pt
|
||||
```
|
||||
|
||||
Copy `flat-meta.json` to `<output_dir>`.
|
||||
|
||||
Then cd to this dir, and we unflatten parameters.
|
||||
|
||||
```shell
|
||||
bash unflat.sh <output_dir>/ <new_output_dir>/
|
||||
```
|
||||
|
||||
Finally, you will get 8 files in `<new_output_dir>` with following checksums:
|
||||
```
|
||||
6169c59d014be95553c89ec01b8abb62 reshard-model_part-0.pt
|
||||
58868105da3d74a528a548fdb3a8cff6 reshard-model_part-1.pt
|
||||
69b255dc5a49d0eba9e4b60432cda90b reshard-model_part-2.pt
|
||||
002c052461ff9ffb0cdac3d5906f41f2 reshard-model_part-3.pt
|
||||
6d57f72909320d511ffd5f1c668b2beb reshard-model_part-4.pt
|
||||
93c8c4041cdc0c7907cc7afcf15cec2a reshard-model_part-5.pt
|
||||
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt
|
||||
f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt
|
||||
```
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def load_json(path: str):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def parse_shape_info(flat_dir: str):
|
||||
data = load_json(os.path.join(flat_dir, 'shape.json'))
|
||||
flat_info = defaultdict(lambda: defaultdict(list))
|
||||
for k, shape in data.items():
|
||||
matched = re.match(r'decoder.layers.\d+', k)
|
||||
if matched is None:
|
||||
flat_key = 'flat_param_0'
|
||||
else:
|
||||
flat_key = f'{matched[0]}.flat_param_0'
|
||||
flat_info[flat_key]['names'].append(k)
|
||||
flat_info[flat_key]['shapes'].append(shape)
|
||||
flat_info[flat_key]['numels'].append(int(np.prod(shape)))
|
||||
return flat_info
|
||||
|
||||
|
||||
def convert(flat_dir: str, output_dir: str, part: int):
|
||||
flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
|
||||
output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
|
||||
flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
|
||||
flat_sd = torch.load(flat_path)
|
||||
print(f'Loaded flat state dict from {flat_path}')
|
||||
output_sd = {}
|
||||
for flat_key, param_meta in flat_meta.items():
|
||||
flat_param = flat_sd['model'][flat_key]
|
||||
assert sum(param_meta['numels']) == flat_param.numel(
|
||||
), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}'
|
||||
for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])):
|
||||
output_sd[name] = param.view(shape)
|
||||
|
||||
torch.save(output_sd, output_path)
|
||||
print(f'Saved unflat state dict to {output_path}')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('flat_dir')
|
||||
parser.add_argument('output_dir')
|
||||
parser.add_argument('part', type=int)
|
||||
args = parser.parse_args()
|
||||
convert(args.flat_dir, args.output_dir, args.part)
|
File diff suppressed because one or more lines are too long
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env sh
|
||||
|
||||
for i in $(seq 0 7); do
|
||||
python convert_ckpt.py $1 $2 ${i} &
|
||||
done
|
||||
|
||||
wait $(jobs -p)
|
|
@ -1,55 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
from multiprocessing import Pool
|
||||
|
||||
# download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main
|
||||
# you can use whether wget or git lfs
|
||||
|
||||
path = "/path/to/your/ckpt"
|
||||
new_path = "/path/to/the/processed/ckpt/"
|
||||
|
||||
assert os.path.isdir(path)
|
||||
files = []
|
||||
for filename in os.listdir(path):
|
||||
filepath = os.path.join(path, filename)
|
||||
if os.path.isfile(filepath):
|
||||
files.append(filepath)
|
||||
|
||||
with Pool(14) as pool:
|
||||
ckpts = pool.map(torch.load, files)
|
||||
|
||||
restored = {}
|
||||
for ckpt in ckpts:
|
||||
for k,v in ckpt.items():
|
||||
if(k[0] == 'm'):
|
||||
k = k[6:]
|
||||
if(k == "lm_head.weight"):
|
||||
k = "head.dense.weight"
|
||||
if(k == "decoder.final_layer_norm.weight"):
|
||||
k = "decoder.layer_norm.weight"
|
||||
if(k == "decoder.final_layer_norm.bias"):
|
||||
k = "decoder.layer_norm.bias"
|
||||
restored[k] = v
|
||||
restored["decoder.version"] = "0.0"
|
||||
|
||||
|
||||
split_num = len(restored.keys()) // 60
|
||||
count = 0
|
||||
file_count = 1
|
||||
tmp = {}
|
||||
for k,v in restored.items():
|
||||
print(k)
|
||||
tmp[k] = v
|
||||
count = count + 1
|
||||
if(count == split_num):
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
file_count = file_count + 1
|
||||
count = 0
|
||||
tmp = {}
|
||||
|
||||
filename = str(file_count) + "-restored.pt"
|
||||
torch.save(tmp, os.path.join(new_path, filename))
|
||||
|
||||
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
# Train OPT model with Colossal-AI
|
||||
|
||||
## OPT
|
||||
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.
|
||||
|
||||
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
|
||||
|
||||
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
|
||||
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
|
||||
|
||||
## Our Modifications
|
||||
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
|
||||
|
||||
## Quick Start
|
||||
You can launch training by using the following bash script
|
||||
|
||||
```bash
|
||||
bash ./run_clm.sh <batch-size-per-gpu> <mem-cap> <model> <gpu-num>
|
||||
```
|
||||
|
||||
- batch-size-per-gpu: number of samples fed to each GPU, default is 16
|
||||
- mem-cap: limit memory usage within a value in GB, default is 0 (no limit)
|
||||
- model: the size of the OPT model, default is `6.7b`. Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`, you can request
|
||||
the pretrained weights from [OPT weight downloading page](https://github.com/facebookresearch/metaseq/tree/main/projects/OPT).
|
||||
- gpu-num: the number of GPUs to use, default is 1.
|
||||
|
||||
## Remarkable Performance
|
||||
On a single GPU, Colossal-AI’s automatic strategy provides remarkable performance gains from the ZeRO Offloading strategy by Microsoft DeepSpeed.
|
||||
Users can experience up to a 40% speedup, at a variety of model scales. However, when using a traditional deep learning training framework like PyTorch, a single GPU can no longer support the training of models at such a scale.
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/OPT.png" width=1000/>
|
||||
</p>
|
||||
|
||||
Adopting the distributed training strategy with 8 GPUs is as simple as adding a `-nprocs 8` to the training command of Colossal-AI!
|
||||
|
||||
More details about behind the scenes can be found on the corresponding [blog](https://medium.com/@yangyou_berkeley/colossal-ai-seamlessly-accelerates-large-models-at-low-costs-with-hugging-face-4d1a887e500d),
|
||||
and a detailed tutorial will be added in [Documentation](https://www.colossalai.org/docs/get_started/installation) very soon.
|
|
@ -1,21 +0,0 @@
|
|||
export BS=16
|
||||
export MEMCAP=0
|
||||
export MODEL="6.7b"
|
||||
export GPUNUM=1
|
||||
|
||||
for MODEL in "6.7b" "13b" "1.3b"
|
||||
do
|
||||
for GPUNUM in 8 1
|
||||
do
|
||||
for BS in 16 24 32 8
|
||||
do
|
||||
for MEMCAP in 0 40
|
||||
do
|
||||
pkill -9 torchrun
|
||||
pkill -9 python
|
||||
|
||||
bash ./run_clm.sh $BS $MEMCAP $MODEL $GPUNUM
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
|
@ -1,6 +0,0 @@
|
|||
from colossalai.zero.shard_utils import TensorShardStrategy
|
||||
|
||||
zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(),
|
||||
tensor_placement_policy="auto",
|
||||
reuse_fp16_shard=True),
|
||||
optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384))
|
|
@ -1,32 +0,0 @@
|
|||
import torch.distributed as dist
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
class barrier_context():
|
||||
"""
|
||||
This context manager is used to allow one process to execute while blocking all
|
||||
other processes in the same process group. This is often useful when downloading is required
|
||||
as we only want to download in one process to prevent file corruption.
|
||||
Args:
|
||||
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
|
||||
parallel_mode (ParallelMode): the parallel mode corresponding to a process group
|
||||
Usage:
|
||||
with barrier_context():
|
||||
dataset = CIFAR10(root='./data', download=True)
|
||||
"""
|
||||
|
||||
def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
|
||||
# the class name is lowercase by convention
|
||||
current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
|
||||
self.should_block = current_rank != executor_rank
|
||||
self.group = gpc.get_group(parallel_mode=parallel_mode)
|
||||
|
||||
def __enter__(self):
|
||||
if self.should_block:
|
||||
dist.barrier(group=self.group)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
if not self.should_block:
|
||||
dist.barrier(group=self.group)
|
|
@ -1,6 +0,0 @@
|
|||
colossalai
|
||||
torch >= 1.8.1
|
||||
datasets >= 1.8.0
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
accelerate == 0.13.2
|
|
@ -1,596 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
|
||||
on a text file or a dataset without using HuggingFace Trainer.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=text-generation
|
||||
"""
|
||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from itertools import chain
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.utils import set_seed
|
||||
from context import barrier_context
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import colossalai
|
||||
import transformers
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.utils import get_current_device, get_dataloader
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
GPT2Tokenizer,
|
||||
OPTForCausalLM,
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
def get_time_stamp():
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument("--train_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A csv or a json file containing the training data.")
|
||||
parser.add_argument("--validation_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A csv or a json file containing the validation data.")
|
||||
parser.add_argument(
|
||||
"--validation_split_percentage",
|
||||
default=5,
|
||||
help="The percentage of the train set used as validation set in case there's no validation split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_eval_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the evaluation dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=SchedulerType,
|
||||
default="linear",
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument("--num_warmup_steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of steps for the warmup in the lr scheduler.")
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model type to use if training from scratch.",
|
||||
choices=MODEL_TYPES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of"
|
||||
" this size for training. Default to the model max input length for single sentence inputs (take into"
|
||||
" account special tokens)."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of processes to use for the preprocessing.",
|
||||
)
|
||||
parser.add_argument("--overwrite_cache",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Overwrite the cached training and evaluation sets")
|
||||
parser.add_argument("--no_keep_linebreaks",
|
||||
action="store_true",
|
||||
help="Do not keep line breaks when using TXT files.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_model_id",
|
||||
type=str,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.")
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
help="Whether to enable experiment trackers for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="all",
|
||||
help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
||||
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
|
||||
"Only applicable when `--with_tracking` is passed."),
|
||||
)
|
||||
|
||||
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
|
||||
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_file is None and args.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if args.train_file is not None:
|
||||
extension = args.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file."
|
||||
if args.validation_file is not None:
|
||||
extension = args.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file."
|
||||
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def colo_memory_cap(size_in_GB):
|
||||
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
if size_in_GB * (1024**3) < cuda_capacity:
|
||||
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
|
||||
print("Using {} GB of GPU memory".format(size_in_GB))
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
disable_existing_loggers()
|
||||
colossalai.launch_from_torch(config=dict())
|
||||
logger = get_dist_logger()
|
||||
is_main_process = dist.get_rank() == 0
|
||||
|
||||
if is_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
if args.mem_cap > 0:
|
||||
colo_memory_cap(args.mem_cap)
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")
|
||||
|
||||
# Handle the repository creation
|
||||
with barrier_context():
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
logger.info("Start preparing dataset", ranks=[0])
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
split=f"train[:{args.validation_split_percentage}%]",
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
split=f"train[{args.validation_split_percentage}%:]",
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
dataset_args = {}
|
||||
if args.train_file is not None:
|
||||
data_files["train"] = args.train_file
|
||||
if args.validation_file is not None:
|
||||
data_files["validation"] = args.validation_file
|
||||
extension = args.train_file.split(".")[-1]
|
||||
if extension == "txt":
|
||||
extension = "text"
|
||||
dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
|
||||
raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
|
||||
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
||||
if "validation" not in raw_datasets.keys():
|
||||
raw_datasets["validation"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[:{args.validation_split_percentage}%]",
|
||||
**dataset_args,
|
||||
)
|
||||
raw_datasets["train"] = load_dataset(
|
||||
extension,
|
||||
data_files=data_files,
|
||||
split=f"train[{args.validation_split_percentage}%:]",
|
||||
**dataset_args,
|
||||
)
|
||||
logger.info("Dataset is prepared", ranks=[0])
|
||||
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
if args.config_name:
|
||||
config = AutoConfig.from_pretrained(args.config_name)
|
||||
elif args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
config = CONFIG_MAPPING[args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
logger.info("Model config has been created", ranks=[0])
|
||||
|
||||
if args.model_name_or_path == 'facebook/opt-13b':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
print(f'load model from {args.model_name_or_path}')
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
||||
logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0])
|
||||
|
||||
if args.init_in_cpu:
|
||||
init_dev = torch.device('cpu')
|
||||
else:
|
||||
init_dev = get_current_device()
|
||||
|
||||
# build model
|
||||
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
|
||||
# currently, there has a bug in pretrained opt-13b
|
||||
# we can not import it until huggingface fix it
|
||||
logger.info("Train a new model from scratch", ranks=[0])
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = OPTForCausalLM(config)
|
||||
else:
|
||||
logger.info("Finetune a pre-trained model", ranks=[0])
|
||||
with ColoInitContext(device=init_dev):
|
||||
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
local_files_only=False)
|
||||
|
||||
# enable graident checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
PLACEMENT_POLICY = 'auto'
|
||||
cai_version = colossalai.__version__
|
||||
logger.info(f'using Colossal-AI version {cai_version}')
|
||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
pg = ProcessGroup()
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||
chunk_manager = ChunkManager(chunk_size,
|
||||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
|
||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
|
||||
logger.info(f'{model.__class__.__name__} has been created', ranks=[0])
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# First we tokenize all the texts.
|
||||
column_names = raw_datasets["train"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column_name])
|
||||
|
||||
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
if args.block_size is None:
|
||||
block_size = tokenizer.model_max_length
|
||||
if block_size > 1024:
|
||||
logger.warning(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 1024 instead. You can change that default value by passing --block_size xxx.")
|
||||
block_size = 1024
|
||||
else:
|
||||
if args.block_size > tokenizer.model_max_length:
|
||||
logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model"
|
||||
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.")
|
||||
block_size = min(args.block_size, tokenizer.model_max_length)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
||||
def group_texts(examples):
|
||||
# Concatenate all texts.
|
||||
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
||||
# customize this part to your needs.
|
||||
if total_length >= block_size:
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# Split by chunks of max_len.
|
||||
result = {
|
||||
k: [t[i:i + block_size] for i in range(0, total_length, block_size)
|
||||
] for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
||||
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
||||
# to preprocess.
|
||||
#
|
||||
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
||||
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
||||
|
||||
with barrier_context(executor_rank=0, parallel_mode=ParallelMode.DATA):
|
||||
lm_datasets = tokenized_datasets.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
desc=f"Grouping texts in chunks of {block_size}",
|
||||
)
|
||||
|
||||
train_dataset = lm_datasets["train"]
|
||||
eval_dataset = lm_datasets["validation"]
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
# for index in random.sample(range(len(train_dataset)), 3):
|
||||
# logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = get_dataloader(train_dataset,
|
||||
shuffle=True,
|
||||
add_sampler=True,
|
||||
collate_fn=default_data_collator,
|
||||
batch_size=args.per_device_train_batch_size)
|
||||
eval_dataloader = DataLoader(eval_dataset,
|
||||
collate_fn=default_data_collator,
|
||||
batch_size=args.per_device_eval_batch_size)
|
||||
logger.info("Dataloaders have been created", ranks=[0])
|
||||
|
||||
# Optimizer
|
||||
# Split weights in two groups, one with weight decay and the other not.
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * gpc.get_world_size(ParallelMode.DATA)
|
||||
|
||||
logger.info("***** Running training *****", ranks=[0])
|
||||
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
|
||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}", ranks=[0])
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not is_main_process)
|
||||
completed_steps = 0
|
||||
starting_epoch = 0
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
model.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
loss = outputs['loss']
|
||||
optimizer.backward(loss)
|
||||
|
||||
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
global_step += 1
|
||||
logger.info("Global step {} finished".format(global_step + 1), ranks=[0])
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
model.eval()
|
||||
losses = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
|
||||
loss = outputs['loss'].unsqueeze(0)
|
||||
losses.append(loss)
|
||||
|
||||
losses = torch.cat(losses)
|
||||
losses = losses[:len(eval_dataset)]
|
||||
try:
|
||||
eval_loss = torch.mean(losses)
|
||||
perplexity = math.exp(eval_loss)
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
|
||||
logger.info(f"Epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}", ranks=[0])
|
||||
|
||||
if args.output_dir is not None:
|
||||
model_state = model.state_dict()
|
||||
if is_main_process:
|
||||
torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
|
||||
dist.barrier()
|
||||
# load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps))
|
||||
# model.load_state_dict(load_state, strict=False)
|
||||
|
||||
logger.info("Training finished", ranks=[0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,22 +0,0 @@
|
|||
set -x
|
||||
export BS=${1:-16}
|
||||
export MEMCAP=${2:-0}
|
||||
export MODEL=${3:-"125m"}
|
||||
export GPUNUM=${4:-1}
|
||||
|
||||
# make directory for logs
|
||||
mkdir -p ./logs
|
||||
|
||||
export MODLE_PATH="facebook/opt-${MODEL}"
|
||||
|
||||
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
|
||||
torchrun \
|
||||
--nproc_per_node ${GPUNUM} \
|
||||
--master_port 19198 \
|
||||
run_clm.py \
|
||||
--dataset_name wikitext \
|
||||
--dataset_config_name wikitext-2-raw-v1 \
|
||||
--output_dir $PWD \
|
||||
--mem_cap ${MEMCAP} \
|
||||
--model_name_or_path ${MODLE_PATH} \
|
||||
--per_device_train_batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
|
|
@ -1,16 +0,0 @@
|
|||
## Overview
|
||||
This example shows how to use ColossalAI to run huggingface GPT training with Gemini and ZeRO DDP.
|
||||
|
||||
## GPT
|
||||
We use the huggingface transformers GPT2 model. The input data is randonly generated.
|
||||
|
||||
## Our Modifications
|
||||
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
|
||||
|
||||
## Quick Start
|
||||
You can launch training by using the following bash script
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
bash run.sh
|
||||
```
|
|
@ -1,3 +0,0 @@
|
|||
colossalai >= 0.1.10
|
||||
torch >= 1.8.1
|
||||
transformers >= 4.231
|
|
@ -1 +0,0 @@
|
|||
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
|
|
@ -1,241 +0,0 @@
|
|||
from functools import partial
|
||||
from time import time
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument(
|
||||
"--tp_degree",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor Parallelism Degree.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placement",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Placement Policy for Gemini.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
## Parameter Sharding Strategies for Tensor Parallelism
|
||||
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
if param.process_group.tp_world_size() == 1:
|
||||
param.set_process_group(pg)
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(0, param, pg)
|
||||
|
||||
|
||||
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||
split_param_single_dim_tp1d(-1, param, pg)
|
||||
|
||||
|
||||
## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size=768,
|
||||
num_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_seq_len=1024,
|
||||
vocab_size=50257,
|
||||
checkpoint=False):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.model = GPT2LMHeadModel(
|
||||
GPT2Config(n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size))
|
||||
if checkpoint:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
## Randomly Generated Data
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def gpt2_medium(checkpoint=False):
|
||||
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
|
||||
|
||||
|
||||
def gpt2_xl(checkpoint=True):
|
||||
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
|
||||
|
||||
|
||||
def gpt2_10b(checkpoint=True):
|
||||
return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
return psutil.Process().memory_info().rss / 1024**2
|
||||
|
||||
|
||||
def get_gpu_mem():
|
||||
return torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
|
||||
def get_mem_info(prefix=''):
|
||||
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||
|
||||
|
||||
def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||
|
||||
|
||||
# Tensor Parallel
|
||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
"""tensor_parallelize
|
||||
Sharding the Model Parameters.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): a torch module to be sharded
|
||||
"""
|
||||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
# set process group for all parameters
|
||||
param.set_process_group(pg)
|
||||
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
# keep the shape of the output from c_fc
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
|
||||
|
||||
# Gemini + ZeRO DDP
|
||||
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
|
||||
cai_version = colossalai.__version__
|
||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||
from colossalai.nn.parallel import GeminiDDP
|
||||
model = GeminiDDP(model,
|
||||
device=get_current_device(),
|
||||
placement_policy=placememt_policy,
|
||||
pin_memory=True,
|
||||
search_range_mb=32)
|
||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
|
||||
chunk_manager = ChunkManager(chunk_size,
|
||||
pg,
|
||||
enable_distributed_storage=True,
|
||||
init_device=GeminiManager.get_default_device(placememt_policy))
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
else:
|
||||
raise NotImplemented(f"CAI version {cai_version} is not supported")
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LEN = 1024
|
||||
VOCAB_SIZE = 50257
|
||||
NUM_STEPS = 10
|
||||
|
||||
disable_existing_loggers()
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info(get_mem_info(), ranks=[0])
|
||||
|
||||
# build GPT model
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = gpt2_medium(checkpoint=True)
|
||||
|
||||
numel = sum([p.numel() for p in model.parameters()])
|
||||
logger.info(f'Model numel: {numel}', ranks=[0])
|
||||
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
||||
|
||||
# Tensor Parallelism (TP)
|
||||
tensor_parallelize(model, pg)
|
||||
# Gemini + ZeRO DP, Note it must be used after TP
|
||||
model = gemini_zero_dpp(model, pg, args.placement)
|
||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||
|
||||
# build criterion
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
# build optimizer
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
|
||||
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
||||
|
||||
torch.cuda.synchronize()
|
||||
model.train()
|
||||
for n in range(NUM_STEPS):
|
||||
# we just use randomly generated data here
|
||||
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
|
||||
optimizer.zero_grad()
|
||||
start = time()
|
||||
outputs = model(input_ids, attn_mask)
|
||||
loss = criterion(outputs, input_ids)
|
||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
|
||||
optimizer.backward(loss)
|
||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
|
||||
optimizer.step()
|
||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
||||
step_time = time() - start
|
||||
logger.info(
|
||||
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
|
||||
ranks=[0])
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,82 +0,0 @@
|
|||
Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
|
||||
CreativeML Open RAIL-M
|
||||
dated August 22, 2022
|
||||
|
||||
Section I: PREAMBLE
|
||||
|
||||
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
|
||||
|
||||
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
||||
|
||||
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
|
||||
|
||||
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
||||
|
||||
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
||||
|
||||
NOW THEREFORE, You and Licensor agree as follows:
|
||||
|
||||
1. Definitions
|
||||
|
||||
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
||||
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
||||
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
||||
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
||||
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
||||
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
||||
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
||||
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
||||
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
||||
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
|
||||
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
|
||||
|
||||
Section II: INTELLECTUAL PROPERTY RIGHTS
|
||||
|
||||
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
|
||||
|
||||
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
||||
|
||||
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
||||
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
||||
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
||||
You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
||||
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
||||
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
||||
|
||||
Section IV: OTHER PROVISIONS
|
||||
|
||||
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
|
||||
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
|
||||
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
||||
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
|
||||
|
||||
|
||||
Attachment A
|
||||
|
||||
Use Restrictions
|
||||
|
||||
You agree not to use the Model or Derivatives of the Model:
|
||||
- In any way that violates any applicable national, federal, state, local or international law or regulation;
|
||||
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
||||
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
||||
- To generate or disseminate personal identifiable information that can be used to harm an individual;
|
||||
- To defame, disparage or otherwise harass others;
|
||||
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
||||
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
||||
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
||||
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
|
||||
- To provide medical advice and medical results interpretation;
|
||||
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
|
|
@ -1,115 +0,0 @@
|
|||
# Handson 6: Acceleration of Stable Diffusion
|
||||
|
||||
*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and
|
||||
fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*
|
||||
|
||||
We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies
|
||||
, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
|
||||
|
||||
## Stable Diffusion
|
||||
[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
|
||||
model.
|
||||
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
|
||||
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
|
||||
this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
|
||||
|
||||
<p id="diffusion_train" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/diffusion_train.png" width=800/>
|
||||
</p>
|
||||
|
||||
[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB).
|
||||
|
||||
<p id="diffusion_demo" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/diffusion_demo.png" width=800/>
|
||||
</p>
|
||||
|
||||
## Requirements
|
||||
A suitable [conda](https://conda.io/) environment named `ldm` can be created
|
||||
and activated with:
|
||||
|
||||
```
|
||||
conda env create -f environment.yaml
|
||||
conda activate ldm
|
||||
```
|
||||
|
||||
You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
|
||||
|
||||
```
|
||||
conda install pytorch torchvision -c pytorch
|
||||
pip install transformers==4.19.2 diffusers invisible-watermark
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website
|
||||
```
|
||||
pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
|
||||
```
|
||||
|
||||
### Install [Lightning](https://github.com/Lightning-AI/lightning)
|
||||
We use the Sep. 2022 version with commit id as `b04a7aa`.
|
||||
```
|
||||
git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
|
||||
pip install -r requirements.txt && pip install .
|
||||
```
|
||||
|
||||
> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.
|
||||
|
||||
## Dataset
|
||||
The DataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/),
|
||||
you should the change the `data.file_path` in the `config/train_colossalai.yaml`
|
||||
|
||||
## Training
|
||||
|
||||
we provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml`, `train_ddp.yaml`
|
||||
|
||||
for example, you can run the training from colossalai by
|
||||
```
|
||||
python main.py --logdir /tmp -t --postfix test -b config/train_colossalai.yaml
|
||||
```
|
||||
|
||||
- you can change the `--logdir` the save the log information and the last checkpoint
|
||||
|
||||
### Training config
|
||||
you can change the trainging config in the yaml file
|
||||
|
||||
- accelerator: acceleratortype, default 'gpu'
|
||||
- devices: device number used for training, default 4
|
||||
- max_epochs: max training epochs
|
||||
- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai
|
||||
|
||||
|
||||
## Comments
|
||||
|
||||
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||
, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch),
|
||||
[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion).
|
||||
Thanks for open-sourcing!
|
||||
|
||||
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
|
||||
|
||||
- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch).
|
||||
|
||||
## BibTeX
|
||||
|
||||
```
|
||||
@article{bian2021colossal,
|
||||
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
|
||||
author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
|
||||
journal={arXiv preprint arXiv:2110.14883},
|
||||
year={2021}
|
||||
}
|
||||
@misc{rombach2021highresolution,
|
||||
title={High-Resolution Image Synthesis with Latent Diffusion Models},
|
||||
author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
|
||||
year={2021},
|
||||
eprint={2112.10752},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
@article{dao2022flashattention,
|
||||
title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
|
||||
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
|
||||
journal={arXiv preprint arXiv:2205.14135},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -1,116 +0,0 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 64
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.base.Txt2ImgIterableBaseDataset
|
||||
params:
|
||||
file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
world_size: 1
|
||||
rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.ColossalAIStrategy
|
||||
params:
|
||||
use_chunk: False
|
||||
enable_distributed_storage: True,
|
||||
placement_policy: cuda
|
||||
force_outputs_fp32: False
|
||||
|
||||
log_every_n_steps: 2
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: "/tmp/diff_log/"
|
||||
offline: opt.debug
|
||||
id: nowname
|
|
@ -1,113 +0,0 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 32
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 100 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 64
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.base.Txt2ImgIterableBaseDataset
|
||||
params:
|
||||
file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
world_size: 1
|
||||
rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.DDPStrategy
|
||||
params:
|
||||
find_unused_parameters: False
|
||||
log_every_n_steps: 2
|
||||
# max_steps: 6o
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
# profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: "/tmp/diff_log/"
|
||||
offline: opt.debug
|
||||
id: nowname
|
|
@ -1,121 +0,0 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 32
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
check_nan_inf: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1.e-4 ]
|
||||
f_min: [ 1.e-10 ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: False
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
use_fp16: True
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 32
|
||||
wrap: False
|
||||
train:
|
||||
target: ldm.data.pokemon.PokemonDataset
|
||||
# params:
|
||||
# file_path: "/data/scratch/diffuser/laion_part0/"
|
||||
# world_size: 1
|
||||
# rank: 0
|
||||
|
||||
lightning:
|
||||
trainer:
|
||||
accelerator: 'gpu'
|
||||
devices: 4
|
||||
log_gpu_memory: all
|
||||
max_epochs: 2
|
||||
precision: 16
|
||||
auto_select_gpus: False
|
||||
strategy:
|
||||
target: pytorch_lightning.strategies.ColossalAIStrategy
|
||||
params:
|
||||
use_chunk: False
|
||||
enable_distributed_storage: True,
|
||||
placement_policy: cuda
|
||||
force_outputs_fp32: False
|
||||
initial_scale: 65536
|
||||
min_scale: 1
|
||||
max_scale: 65536
|
||||
# max_scale: 4294967296
|
||||
|
||||
log_every_n_steps: 2
|
||||
logger: True
|
||||
default_root_dir: "/tmp/diff_log/"
|
||||
profiler: pytorch
|
||||
|
||||
logger_config:
|
||||
wandb:
|
||||
target: pytorch_lightning.loggers.WandbLogger
|
||||
params:
|
||||
name: nowname
|
||||
save_dir: "/tmp/diff_log/"
|
||||
offline: opt.debug
|
||||
id: nowname
|
|
@ -1,32 +0,0 @@
|
|||
name: ldm
|
||||
channels:
|
||||
- pytorch
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.12
|
||||
- pip=20.3
|
||||
- cudatoolkit=11.3
|
||||
- pytorch=1.11.0
|
||||
- torchvision=0.12.0
|
||||
- numpy=1.19.2
|
||||
- pip:
|
||||
- albumentations==0.4.3
|
||||
- diffusers
|
||||
- opencv-python==4.6.0.66
|
||||
- pudb==2019.2
|
||||
- invisible-watermark
|
||||
- imageio==2.9.0
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- pytorch-lightning==1.4.2
|
||||
- omegaconf==2.1.1
|
||||
- test-tube>=0.7.5
|
||||
- streamlit>=0.73.1
|
||||
- einops==0.3.0
|
||||
- torch-fidelity==0.3.0
|
||||
- transformers==4.19.2
|
||||
- torchmetrics==0.6.0
|
||||
- kornia==0.6
|
||||
- prefetch_generator
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e .
|
|
@ -1,75 +0,0 @@
|
|||
import math
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
'''
|
||||
Define an interface to make the IterableDatasets for text2img data chainable
|
||||
'''
|
||||
def __init__(self, file_path: str, rank, world_size):
|
||||
super().__init__()
|
||||
self.file_path = file_path
|
||||
self.folder_list = []
|
||||
self.file_list = []
|
||||
self.txt_list = []
|
||||
self.info = self._get_file_info(file_path)
|
||||
self.start = self.info['start']
|
||||
self.end = self.info['end']
|
||||
self.rank = rank
|
||||
|
||||
self.world_size = world_size
|
||||
# self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
|
||||
# self.iter_start = self.start + self.rank * self.per_worker
|
||||
# self.iter_end = min(self.iter_start + self.per_worker, self.end)
|
||||
# self.num_records = self.iter_end - self.iter_start
|
||||
# self.valid_ids = [i for i in range(self.iter_end)]
|
||||
self.num_records = self.end - self.start
|
||||
self.valid_ids = [i for i in range(self.end)]
|
||||
|
||||
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
|
||||
|
||||
def __len__(self):
|
||||
# return self.iter_end - self.iter_start
|
||||
return self.end - self.start
|
||||
|
||||
def __iter__(self):
|
||||
sample_iterator = self._sample_generator(self.start, self.end)
|
||||
# sample_iterator = self._sample_generator(self.iter_start, self.iter_end)
|
||||
return sample_iterator
|
||||
|
||||
def _sample_generator(self, start, end):
|
||||
for idx in range(start, end):
|
||||
file_name = self.file_list[idx]
|
||||
txt_name = self.txt_list[idx]
|
||||
f_ = open(txt_name, 'r')
|
||||
txt_ = f_.read()
|
||||
f_.close()
|
||||
image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = torch.from_numpy(image) / 255
|
||||
yield {"caption": txt_, "image":image}
|
||||
|
||||
|
||||
def _get_file_info(self, file_path):
|
||||
info = \
|
||||
{
|
||||
"start": 1,
|
||||
"end": 0,
|
||||
}
|
||||
self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
|
||||
for folder in self.folder_list:
|
||||
files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
|
||||
txts = [k.replace('jpg', 'txt') for k in files]
|
||||
self.file_list.extend(files)
|
||||
self.txt_list.extend(txts)
|
||||
info['end'] = len(self.file_list)
|
||||
# with open(file_path, 'r') as fin:
|
||||
# for _ in enumerate(fin):
|
||||
# info['end'] += 1
|
||||
# self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list]
|
||||
return info
|
|
@ -1,394 +0,0 @@
|
|||
import os, yaml, pickle, shutil, tarfile, glob
|
||||
import cv2
|
||||
import albumentations
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torchvision.transforms.functional as TF
|
||||
from omegaconf import OmegaConf
|
||||
from functools import partial
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
import taming.data.utils as tdu
|
||||
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
|
||||
from taming.data.imagenet import ImagePaths
|
||||
|
||||
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
|
||||
|
||||
|
||||
def synset2idx(path_to_yaml="data/index_synset.yaml"):
|
||||
with open(path_to_yaml) as f:
|
||||
di2s = yaml.load(f)
|
||||
return dict((v,k) for k,v in di2s.items())
|
||||
|
||||
|
||||
class ImageNetBase(Dataset):
|
||||
def __init__(self, config=None):
|
||||
self.config = config or OmegaConf.create()
|
||||
if not type(self.config)==dict:
|
||||
self.config = OmegaConf.to_container(self.config)
|
||||
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
|
||||
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
|
||||
self._prepare()
|
||||
self._prepare_synset_to_human()
|
||||
self._prepare_idx_to_synset()
|
||||
self._prepare_human_to_integer_label()
|
||||
self._load()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.data[i]
|
||||
|
||||
def _prepare(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _filter_relpaths(self, relpaths):
|
||||
ignore = set([
|
||||
"n06596364_9591.JPEG",
|
||||
])
|
||||
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
|
||||
if "sub_indices" in self.config:
|
||||
indices = str_to_indices(self.config["sub_indices"])
|
||||
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
|
||||
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
|
||||
files = []
|
||||
for rpath in relpaths:
|
||||
syn = rpath.split("/")[0]
|
||||
if syn in synsets:
|
||||
files.append(rpath)
|
||||
return files
|
||||
else:
|
||||
return relpaths
|
||||
|
||||
def _prepare_synset_to_human(self):
|
||||
SIZE = 2655750
|
||||
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
|
||||
self.human_dict = os.path.join(self.root, "synset_human.txt")
|
||||
if (not os.path.exists(self.human_dict) or
|
||||
not os.path.getsize(self.human_dict)==SIZE):
|
||||
download(URL, self.human_dict)
|
||||
|
||||
def _prepare_idx_to_synset(self):
|
||||
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
|
||||
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
|
||||
if (not os.path.exists(self.idx2syn)):
|
||||
download(URL, self.idx2syn)
|
||||
|
||||
def _prepare_human_to_integer_label(self):
|
||||
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
|
||||
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
|
||||
if (not os.path.exists(self.human2integer)):
|
||||
download(URL, self.human2integer)
|
||||
with open(self.human2integer, "r") as f:
|
||||
lines = f.read().splitlines()
|
||||
assert len(lines) == 1000
|
||||
self.human2integer_dict = dict()
|
||||
for line in lines:
|
||||
value, key = line.split(":")
|
||||
self.human2integer_dict[key] = int(value)
|
||||
|
||||
def _load(self):
|
||||
with open(self.txt_filelist, "r") as f:
|
||||
self.relpaths = f.read().splitlines()
|
||||
l1 = len(self.relpaths)
|
||||
self.relpaths = self._filter_relpaths(self.relpaths)
|
||||
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
|
||||
|
||||
self.synsets = [p.split("/")[0] for p in self.relpaths]
|
||||
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
|
||||
|
||||
unique_synsets = np.unique(self.synsets)
|
||||
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
|
||||
if not self.keep_orig_class_label:
|
||||
self.class_labels = [class_dict[s] for s in self.synsets]
|
||||
else:
|
||||
self.class_labels = [self.synset2idx[s] for s in self.synsets]
|
||||
|
||||
with open(self.human_dict, "r") as f:
|
||||
human_dict = f.read().splitlines()
|
||||
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
|
||||
|
||||
self.human_labels = [human_dict[s] for s in self.synsets]
|
||||
|
||||
labels = {
|
||||
"relpath": np.array(self.relpaths),
|
||||
"synsets": np.array(self.synsets),
|
||||
"class_label": np.array(self.class_labels),
|
||||
"human_label": np.array(self.human_labels),
|
||||
}
|
||||
|
||||
if self.process_images:
|
||||
self.size = retrieve(self.config, "size", default=256)
|
||||
self.data = ImagePaths(self.abspaths,
|
||||
labels=labels,
|
||||
size=self.size,
|
||||
random_crop=self.random_crop,
|
||||
)
|
||||
else:
|
||||
self.data = self.abspaths
|
||||
|
||||
|
||||
class ImageNetTrain(ImageNetBase):
|
||||
NAME = "ILSVRC2012_train"
|
||||
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
||||
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
|
||||
FILES = [
|
||||
"ILSVRC2012_img_train.tar",
|
||||
]
|
||||
SIZES = [
|
||||
147897477120,
|
||||
]
|
||||
|
||||
def __init__(self, process_images=True, data_root=None, **kwargs):
|
||||
self.process_images = process_images
|
||||
self.data_root = data_root
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare(self):
|
||||
if self.data_root:
|
||||
self.root = os.path.join(self.data_root, self.NAME)
|
||||
else:
|
||||
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
||||
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
||||
|
||||
self.datadir = os.path.join(self.root, "data")
|
||||
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
||||
self.expected_length = 1281167
|
||||
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
|
||||
default=True)
|
||||
if not tdu.is_prepared(self.root):
|
||||
# prep
|
||||
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
||||
|
||||
datadir = self.datadir
|
||||
if not os.path.exists(datadir):
|
||||
path = os.path.join(self.root, self.FILES[0])
|
||||
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
||||
import academictorrents as at
|
||||
atpath = at.get(self.AT_HASH, datastore=self.root)
|
||||
assert atpath == path
|
||||
|
||||
print("Extracting {} to {}".format(path, datadir))
|
||||
os.makedirs(datadir, exist_ok=True)
|
||||
with tarfile.open(path, "r:") as tar:
|
||||
tar.extractall(path=datadir)
|
||||
|
||||
print("Extracting sub-tars.")
|
||||
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
|
||||
for subpath in tqdm(subpaths):
|
||||
subdir = subpath[:-len(".tar")]
|
||||
os.makedirs(subdir, exist_ok=True)
|
||||
with tarfile.open(subpath, "r:") as tar:
|
||||
tar.extractall(path=subdir)
|
||||
|
||||
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
||||
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
||||
filelist = sorted(filelist)
|
||||
filelist = "\n".join(filelist)+"\n"
|
||||
with open(self.txt_filelist, "w") as f:
|
||||
f.write(filelist)
|
||||
|
||||
tdu.mark_prepared(self.root)
|
||||
|
||||
|
||||
class ImageNetValidation(ImageNetBase):
|
||||
NAME = "ILSVRC2012_validation"
|
||||
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
||||
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
|
||||
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
|
||||
FILES = [
|
||||
"ILSVRC2012_img_val.tar",
|
||||
"validation_synset.txt",
|
||||
]
|
||||
SIZES = [
|
||||
6744924160,
|
||||
1950000,
|
||||
]
|
||||
|
||||
def __init__(self, process_images=True, data_root=None, **kwargs):
|
||||
self.data_root = data_root
|
||||
self.process_images = process_images
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare(self):
|
||||
if self.data_root:
|
||||
self.root = os.path.join(self.data_root, self.NAME)
|
||||
else:
|
||||
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
||||
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
||||
self.datadir = os.path.join(self.root, "data")
|
||||
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
||||
self.expected_length = 50000
|
||||
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
|
||||
default=False)
|
||||
if not tdu.is_prepared(self.root):
|
||||
# prep
|
||||
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
||||
|
||||
datadir = self.datadir
|
||||
if not os.path.exists(datadir):
|
||||
path = os.path.join(self.root, self.FILES[0])
|
||||
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
||||
import academictorrents as at
|
||||
atpath = at.get(self.AT_HASH, datastore=self.root)
|
||||
assert atpath == path
|
||||
|
||||
print("Extracting {} to {}".format(path, datadir))
|
||||
os.makedirs(datadir, exist_ok=True)
|
||||
with tarfile.open(path, "r:") as tar:
|
||||
tar.extractall(path=datadir)
|
||||
|
||||
vspath = os.path.join(self.root, self.FILES[1])
|
||||
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
|
||||
download(self.VS_URL, vspath)
|
||||
|
||||
with open(vspath, "r") as f:
|
||||
synset_dict = f.read().splitlines()
|
||||
synset_dict = dict(line.split() for line in synset_dict)
|
||||
|
||||
print("Reorganizing into synset folders")
|
||||
synsets = np.unique(list(synset_dict.values()))
|
||||
for s in synsets:
|
||||
os.makedirs(os.path.join(datadir, s), exist_ok=True)
|
||||
for k, v in synset_dict.items():
|
||||
src = os.path.join(datadir, k)
|
||||
dst = os.path.join(datadir, v)
|
||||
shutil.move(src, dst)
|
||||
|
||||
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
||||
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
||||
filelist = sorted(filelist)
|
||||
filelist = "\n".join(filelist)+"\n"
|
||||
with open(self.txt_filelist, "w") as f:
|
||||
f.write(filelist)
|
||||
|
||||
tdu.mark_prepared(self.root)
|
||||
|
||||
|
||||
|
||||
class ImageNetSR(Dataset):
|
||||
def __init__(self, size=None,
|
||||
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
|
||||
random_crop=True):
|
||||
"""
|
||||
Imagenet Superresolution Dataloader
|
||||
Performs following ops in order:
|
||||
1. crops a crop of size s from image either as random or center crop
|
||||
2. resizes crop to size with cv2.area_interpolation
|
||||
3. degrades resized crop with degradation_fn
|
||||
|
||||
:param size: resizing to size after cropping
|
||||
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
|
||||
:param downscale_f: Low Resolution Downsample factor
|
||||
:param min_crop_f: determines crop size s,
|
||||
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
|
||||
:param max_crop_f: ""
|
||||
:param data_root:
|
||||
:param random_crop:
|
||||
"""
|
||||
self.base = self.get_base()
|
||||
assert size
|
||||
assert (size / downscale_f).is_integer()
|
||||
self.size = size
|
||||
self.LR_size = int(size / downscale_f)
|
||||
self.min_crop_f = min_crop_f
|
||||
self.max_crop_f = max_crop_f
|
||||
assert(max_crop_f <= 1.)
|
||||
self.center_crop = not random_crop
|
||||
|
||||
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
|
||||
|
||||
if degradation == "bsrgan":
|
||||
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
|
||||
|
||||
elif degradation == "bsrgan_light":
|
||||
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
|
||||
|
||||
else:
|
||||
interpolation_fn = {
|
||||
"cv_nearest": cv2.INTER_NEAREST,
|
||||
"cv_bilinear": cv2.INTER_LINEAR,
|
||||
"cv_bicubic": cv2.INTER_CUBIC,
|
||||
"cv_area": cv2.INTER_AREA,
|
||||
"cv_lanczos": cv2.INTER_LANCZOS4,
|
||||
"pil_nearest": PIL.Image.NEAREST,
|
||||
"pil_bilinear": PIL.Image.BILINEAR,
|
||||
"pil_bicubic": PIL.Image.BICUBIC,
|
||||
"pil_box": PIL.Image.BOX,
|
||||
"pil_hamming": PIL.Image.HAMMING,
|
||||
"pil_lanczos": PIL.Image.LANCZOS,
|
||||
}[degradation]
|
||||
|
||||
self.pil_interpolation = degradation.startswith("pil_")
|
||||
|
||||
if self.pil_interpolation:
|
||||
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
|
||||
|
||||
else:
|
||||
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
|
||||
interpolation=interpolation_fn)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.base)
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = self.base[i]
|
||||
image = Image.open(example["file_path_"])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
image = np.array(image).astype(np.uint8)
|
||||
|
||||
min_side_len = min(image.shape[:2])
|
||||
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
|
||||
crop_side_len = int(crop_side_len)
|
||||
|
||||
if self.center_crop:
|
||||
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
|
||||
|
||||
else:
|
||||
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
|
||||
|
||||
image = self.cropper(image=image)["image"]
|
||||
image = self.image_rescaler(image=image)["image"]
|
||||
|
||||
if self.pil_interpolation:
|
||||
image_pil = PIL.Image.fromarray(image)
|
||||
LR_image = self.degradation_process(image_pil)
|
||||
LR_image = np.array(LR_image).astype(np.uint8)
|
||||
|
||||
else:
|
||||
LR_image = self.degradation_process(image=image)["image"]
|
||||
|
||||
example["image"] = (image/127.5 - 1.0).astype(np.float32)
|
||||
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
|
||||
|
||||
return example
|
||||
|
||||
|
||||
class ImageNetSRTrain(ImageNetSR):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_base(self):
|
||||
with open("data/imagenet_train_hr_indices.p", "rb") as f:
|
||||
indices = pickle.load(f)
|
||||
dset = ImageNetTrain(process_images=False,)
|
||||
return Subset(dset, indices)
|
||||
|
||||
|
||||
class ImageNetSRValidation(ImageNetSR):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_base(self):
|
||||
with open("data/imagenet_val_hr_indices.p", "rb") as f:
|
||||
indices = pickle.load(f)
|
||||
dset = ImageNetValidation(process_images=False,)
|
||||
return Subset(dset, indices)
|
|
@ -1,92 +0,0 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class LSUNBase(Dataset):
|
||||
def __init__(self,
|
||||
txt_file,
|
||||
data_root,
|
||||
size=None,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5
|
||||
):
|
||||
self.data_paths = txt_file
|
||||
self.data_root = data_root
|
||||
with open(self.data_paths, "r") as f:
|
||||
self.image_paths = f.read().splitlines()
|
||||
self._length = len(self.image_paths)
|
||||
self.labels = {
|
||||
"relative_file_path_": [l for l in self.image_paths],
|
||||
"file_path_": [os.path.join(self.data_root, l)
|
||||
for l in self.image_paths],
|
||||
}
|
||||
|
||||
self.size = size
|
||||
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = dict((k, self.labels[k][i]) for k in self.labels)
|
||||
image = Image.open(example["file_path_"])
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = img.shape[0], img.shape[1]
|
||||
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||
(w - crop) // 2:(w + crop) // 2]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
if self.size is not None:
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||
return example
|
||||
|
||||
|
||||
class LSUNChurchesTrain(LSUNBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
|
||||
|
||||
|
||||
class LSUNChurchesValidation(LSUNBase):
|
||||
def __init__(self, flip_p=0., **kwargs):
|
||||
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
|
||||
flip_p=flip_p, **kwargs)
|
||||
|
||||
|
||||
class LSUNBedroomsTrain(LSUNBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
|
||||
|
||||
|
||||
class LSUNBedroomsValidation(LSUNBase):
|
||||
def __init__(self, flip_p=0.0, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
|
||||
flip_p=flip_p, **kwargs)
|
||||
|
||||
|
||||
class LSUNCatsTrain(LSUNBase):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
|
||||
|
||||
|
||||
class LSUNCatsValidation(LSUNBase):
|
||||
def __init__(self, flip_p=0., **kwargs):
|
||||
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
|
||||
flip_p=flip_p, **kwargs)
|
|
@ -1,98 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
note: use with a base_lr of 1.0
|
||||
"""
|
||||
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.lr_start = lr_start
|
||||
self.lr_min = lr_min
|
||||
self.lr_max = lr_max
|
||||
self.lr_max_decay_steps = max_decay_steps
|
||||
self.last_lr = 0.
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||
t = min(t, 1.0)
|
||||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
||||
1 + np.cos(t * np.pi))
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n,**kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
|
||||
"""
|
||||
supports repeated iterations, configurable via lists
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
||||
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.f_start = f_start
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.cycle_lengths = cycle_lengths
|
||||
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||
self.last_f = 0.
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def find_in_interval(self, n):
|
||||
interval = 0
|
||||
for cl in self.cum_cycles[1:]:
|
||||
if n <= cl:
|
||||
return interval
|
||||
interval += 1
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}")
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
1 + np.cos(t * np.pi))
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}")
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
||||
self.last_f = f
|
||||
return f
|
||||
|
|
@ -1,544 +0,0 @@
|
|||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
from_pretrained: str=None
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
from diffusers.modeling_utils import load_state_dict
|
||||
if from_pretrained is not None:
|
||||
state_dict = load_state_dict(from_pretrained)
|
||||
self._load_pretrained_model(state_dict)
|
||||
|
||||
def _state_key_mapping(self, state_dict: dict):
|
||||
import re
|
||||
res_dict = {}
|
||||
key_list = state_dict.keys()
|
||||
key_str = " ".join(key_list)
|
||||
up_block_pattern = re.compile('upsamplers')
|
||||
p1 = re.compile('mid.block_[0-9]')
|
||||
p2 = re.compile('decoder.up.[0-9]')
|
||||
up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1)
|
||||
for key_, val_ in state_dict.items():
|
||||
key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
|
||||
.replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
|
||||
.replace('mid.attentions.0.key', 'mid.attn_1.k')\
|
||||
.replace('mid.attentions.0.query', 'mid.attn_1.q') \
|
||||
.replace('mid.attentions.0.value', 'mid.attn_1.v') \
|
||||
.replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
|
||||
.replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
|
||||
.replace('upsamplers.0', 'upsample')\
|
||||
.replace('downsamplers.0', 'downsample')\
|
||||
.replace('conv_shortcut', 'nin_shortcut')\
|
||||
.replace('conv_norm_out', 'norm_out')
|
||||
|
||||
mid_list = re.findall(p1, key_)
|
||||
if len(mid_list) != 0:
|
||||
mid_str = mid_list[0]
|
||||
mid_id = int(mid_str[-1]) + 1
|
||||
key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))
|
||||
|
||||
up_list = re.findall(p2, key_)
|
||||
if len(up_list) != 0:
|
||||
up_str = up_list[0]
|
||||
up_id = up_blocks_count - 1 -int(up_str[-1])
|
||||
key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
|
||||
res_dict[key_] = val_
|
||||
return res_dict
|
||||
|
||||
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
|
||||
state_dict = self._state_key_mapping(state_dict)
|
||||
model_state_dict = self.state_dict()
|
||||
loaded_keys = [k for k in state_dict.keys()]
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
original_loaded_keys = loaded_keys
|
||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
original_loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
error_msgs = self._load_state_dict_into_model(state_dict)
|
||||
return missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||
|
||||
def _load_state_dict_into_model(self, state_dict):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: torch.nn.Module, prefix=""):
|
||||
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(self)
|
||||
|
||||
return error_msgs
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
|
@ -1,267 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn import functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from copy import deepcopy
|
||||
from einops import rearrange
|
||||
from glob import glob
|
||||
from natsort import natsorted
|
||||
|
||||
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
||||
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
||||
|
||||
__models__ = {
|
||||
'class_label': EncoderUNetModel,
|
||||
'segmentation': UNetModel
|
||||
}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
|
||||
def __init__(self,
|
||||
diffusion_path,
|
||||
num_classes,
|
||||
ckpt_path=None,
|
||||
pool='attention',
|
||||
label_key=None,
|
||||
diffusion_ckpt_path=None,
|
||||
scheduler_config=None,
|
||||
weight_decay=1.e-2,
|
||||
log_steps=10,
|
||||
monitor='val/loss',
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_classes = num_classes
|
||||
# get latest config of diffusion model
|
||||
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
||||
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||
self.load_diffusion()
|
||||
|
||||
self.monitor = monitor
|
||||
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
||||
self.log_steps = log_steps
|
||||
|
||||
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
||||
else self.diffusion_model.cond_stage_key
|
||||
|
||||
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
||||
|
||||
if self.label_key not in __models__:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.load_classifier(ckpt_path, pool)
|
||||
|
||||
self.scheduler_config = scheduler_config
|
||||
self.use_scheduler = self.scheduler_config is not None
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||
sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def load_diffusion(self):
|
||||
model = instantiate_from_config(self.diffusion_config)
|
||||
self.diffusion_model = model.eval()
|
||||
self.diffusion_model.train = disabled_train
|
||||
for param in self.diffusion_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_classifier(self, ckpt_path, pool):
|
||||
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
||||
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
||||
model_config.out_channels = self.num_classes
|
||||
if self.label_key == 'class_label':
|
||||
model_config.pool = pool
|
||||
|
||||
self.model = __models__[self.label_key](**model_config)
|
||||
if ckpt_path is not None:
|
||||
print('#####################################################################')
|
||||
print(f'load from ckpt "{ckpt_path}"')
|
||||
print('#####################################################################')
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_x_noisy(self, x, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x))
|
||||
continuous_sqrt_alpha_cumprod = None
|
||||
if self.diffusion_model.use_continuous_noise:
|
||||
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
||||
# todo: make sure t+1 is correct here
|
||||
|
||||
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
||||
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
||||
|
||||
def forward(self, x_noisy, t, *args, **kwargs):
|
||||
return self.model(x_noisy, t)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = rearrange(x, 'b h w c -> b c h w')
|
||||
x = x.to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def get_conditioning(self, batch, k=None):
|
||||
if k is None:
|
||||
k = self.label_key
|
||||
assert k is not None, 'Needs to provide label key'
|
||||
|
||||
targets = batch[k].to(self.device)
|
||||
|
||||
if self.label_key == 'segmentation':
|
||||
targets = rearrange(targets, 'b h w c -> b c h w')
|
||||
for down in range(self.numd):
|
||||
h, w = targets.shape[-2:]
|
||||
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
||||
|
||||
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||
|
||||
return targets
|
||||
|
||||
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
||||
_, top_ks = torch.topk(logits, k, dim=1)
|
||||
if reduction == "mean":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||
elif reduction == "none":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
# save some memory
|
||||
self.diffusion_model.model.to('cpu')
|
||||
|
||||
@torch.no_grad()
|
||||
def write_logs(self, loss, logits, targets):
|
||||
log_prefix = 'train' if self.training else 'val'
|
||||
log = {}
|
||||
log[f"{log_prefix}/loss"] = loss.mean()
|
||||
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
||||
logits, targets, k=1, reduction="mean"
|
||||
)
|
||||
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
||||
logits, targets, k=5, reduction="mean"
|
||||
)
|
||||
|
||||
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
||||
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
||||
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
||||
lr = self.optimizers().param_groups[0]['lr']
|
||||
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
||||
|
||||
def shared_step(self, batch, t=None):
|
||||
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
||||
targets = self.get_conditioning(batch)
|
||||
if targets.dim() == 4:
|
||||
targets = targets.argmax(dim=1)
|
||||
if t is None:
|
||||
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
||||
else:
|
||||
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
||||
x_noisy = self.get_x_noisy(x, t)
|
||||
logits = self(x_noisy, t)
|
||||
|
||||
loss = F.cross_entropy(logits, targets, reduction='none')
|
||||
|
||||
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
||||
|
||||
loss = loss.mean()
|
||||
return loss, logits, x_noisy, targets
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
return loss
|
||||
|
||||
def reset_noise_accs(self):
|
||||
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
||||
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
||||
|
||||
def on_validation_start(self):
|
||||
self.reset_noise_accs()
|
||||
|
||||
@torch.no_grad()
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
|
||||
for t in self.noisy_acc:
|
||||
_, logits, _, targets = self.shared_step(batch, t)
|
||||
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
||||
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
||||
|
||||
if self.use_scheduler:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
}]
|
||||
return [optimizer], scheduler
|
||||
|
||||
return optimizer
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
||||
log['inputs'] = x
|
||||
|
||||
y = self.get_conditioning(batch)
|
||||
|
||||
if self.label_key == 'class_label':
|
||||
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
||||
log['labels'] = y
|
||||
|
||||
if ismap(y):
|
||||
log['labels'] = self.diffusion_model.to_rgb(y)
|
||||
|
||||
for step in range(self.log_steps):
|
||||
current_time = step * self.log_time_interval
|
||||
|
||||
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
||||
|
||||
log[f'inputs@t{current_time}'] = x_noisy
|
||||
|
||||
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
||||
pred = rearrange(pred, 'b h w c -> b c h w')
|
||||
|
||||
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
||||
|
||||
for key in log:
|
||||
log[key] = log[key][:N]
|
||||
|
||||
return log
|
|
@ -1,240 +0,0 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
||||
extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
return x_dec
|
File diff suppressed because it is too large
Load Diff
|
@ -1,236 +0,0 @@
|
|||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
|
@ -1,314 +0,0 @@
|
|||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from torch.utils import checkpoint
|
||||
|
||||
try:
|
||||
from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
|
||||
FlASH_AVAILABLE = True
|
||||
except:
|
||||
FlASH_AVAILABLE = False
|
||||
|
||||
USE_FLASH = False
|
||||
|
||||
|
||||
def enable_flash_attention():
|
||||
global USE_FLASH
|
||||
USE_FLASH = True
|
||||
if FlASH_AVAILABLE is False:
|
||||
print("Please install flash attention to activate new attention kernel.\n" +
|
||||
"Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
dim_head = q.shape[-1] / self.heads
|
||||
|
||||
if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
|
||||
dim_head <= 128 and (dim_head % 8) == 0:
|
||||
# print("in flash")
|
||||
if q.shape[1] == k.shape[1]:
|
||||
out = self._flash_attention_qkv(q, k, v)
|
||||
else:
|
||||
out = self._flash_attention_q_kv(q, k, v)
|
||||
else:
|
||||
out = self._native_attention(q, k, v, self.heads, mask)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _native_attention(self, q, k, v, h, mask):
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
# attention, what we cannot get enough of
|
||||
out = sim.softmax(dim=-1)
|
||||
out = einsum('b i j, b j d -> b i d', out, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return out
|
||||
|
||||
def _flash_attention_qkv(self, q, k, v):
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
b = qkv.shape[0]
|
||||
n = qkv.shape[1]
|
||||
qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
|
||||
out = flash_attention_qkv(qkv, self.scale, b, n)
|
||||
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
|
||||
return out
|
||||
|
||||
def _flash_attention_q_kv(self, q, k, v):
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
b = q.shape[0]
|
||||
q_seqlen = q.shape[1]
|
||||
kv_seqlen = kv.shape[1]
|
||||
q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
|
||||
kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
|
||||
out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
|
||||
out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
|
||||
return out
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
|
||||
|
||||
if self.use_checkpoint:
|
||||
return checkpoint(self._forward, x, context)
|
||||
else:
|
||||
return self._forward(x, context)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None, use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
|
||||
for d in range(depth)]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
x = x.contiguous()
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
x = x.contiguous()
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
|
@ -1,862 +0,0 @@
|
|||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.modules.attention import LinearAttention
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
def __init__(self, in_channels):
|
||||
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "none":
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
return LinAttnBlock(in_channels)
|
||||
|
||||
class temb_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
# self.temb = nn.Module()
|
||||
self.temb = temb_module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# down = nn.Module()
|
||||
down = Down_module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
# self.mid = nn.Module()
|
||||
self.mid = Mid_module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
skip_in = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch*in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# up = nn.Module()
|
||||
up = Up_module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
class Down_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
class Up_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
class Mid_module(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# down = nn.Module()
|
||||
down = Down_module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
# self.mid = nn.Module()
|
||||
self.mid = Mid_module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
attn_type="vanilla", **ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
# middle
|
||||
# self.mid = nn.Module()
|
||||
self.mid = Mid_module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
# up = nn.Module()
|
||||
up = Up_module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
#assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
nn.Conv2d(2*in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True)])
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
if i in [1,2,3]:
|
||||
x = layer(x, None)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
||||
ch_mult=(2,2), dropout=0.0):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = in_channels
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.res_blocks = nn.ModuleList()
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
self.res_blocks.append(nn.ModuleList(res_block))
|
||||
if i_level != self.num_resolutions - 1:
|
||||
self.upsample_blocks.append(Upsample(block_in, True))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
h = x
|
||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.res_blocks[i_level][i_block](h, None)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class LatentRescaler(nn.Module):
|
||||
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
||||
super().__init__()
|
||||
# residual block, interpolate, residual block
|
||||
self.factor = factor
|
||||
self.conv_in = nn.Conv2d(in_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
self.attn = AttnBlock(mid_channels)
|
||||
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
|
||||
self.conv_out = nn.Conv2d(mid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
for block in self.res_block1:
|
||||
x = block(x, None)
|
||||
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
|
||||
x = self.attn(x)
|
||||
for block in self.res_block2:
|
||||
x = block(x, None)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
||||
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
|
||||
super().__init__()
|
||||
intermediate_chn = ch * ch_mult[-1]
|
||||
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn, double_z=False, resolution=resolution,
|
||||
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
|
||||
out_ch=None)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.rescaler(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
|
||||
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
|
||||
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
|
||||
super().__init__()
|
||||
tmp_chn = z_channels*ch_mult[-1]
|
||||
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult, resolution=resolution, ch=ch)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn, depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
||||
super().__init__()
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size//in_size))+1
|
||||
factor_up = 1.+ (out_size % in_size)
|
||||
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
||||
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
||||
out_channels=in_channels)
|
||||
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
||||
attn_resolutions=[], in_channels=None, ch=in_channels,
|
||||
ch_mult=[ch_mult for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Resize(nn.Module):
|
||||
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
||||
super().__init__()
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, scale_factor=1.0):
|
||||
if scale_factor==1.0:
|
||||
return x
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
||||
return x
|
||||
|
||||
class FirstStagePostProcessor(nn.Module):
|
||||
|
||||
def __init__(self, ch_mult:list, in_channels,
|
||||
pretrained_model:nn.Module=None,
|
||||
reshape=False,
|
||||
n_channels=None,
|
||||
dropout=0.,
|
||||
pretrained_config=None):
|
||||
super().__init__()
|
||||
if pretrained_config is None:
|
||||
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
||||
self.pretrained_model = pretrained_model
|
||||
else:
|
||||
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
||||
self.instantiate_pretrained(pretrained_config)
|
||||
|
||||
self.do_reshape = reshape
|
||||
|
||||
if n_channels is None:
|
||||
n_channels = self.pretrained_model.encoder.ch
|
||||
|
||||
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
|
||||
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
|
||||
stride=1,padding=1)
|
||||
|
||||
blocks = []
|
||||
downs = []
|
||||
ch_in = n_channels
|
||||
for m in ch_mult:
|
||||
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
|
||||
ch_in = m * n_channels
|
||||
downs.append(Downsample(ch_in, with_conv=False))
|
||||
|
||||
self.model = nn.ModuleList(blocks)
|
||||
self.downsampler = nn.ModuleList(downs)
|
||||
|
||||
|
||||
def instantiate_pretrained(self, config):
|
||||
model = instantiate_from_config(config)
|
||||
self.pretrained_model = model.eval()
|
||||
# self.pretrained_model.train = False
|
||||
for param in self.pretrained_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_with_pretrained(self,x):
|
||||
c = self.pretrained_model.encode(x)
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
return c
|
||||
|
||||
def forward(self,x):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
z = downmodel(z)
|
||||
|
||||
if self.do_reshape:
|
||||
z = rearrange(z,'b c h w -> b (h w) c')
|
||||
return z
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -1,276 +0,0 @@
|
|||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
)
|
||||
|
||||
elif schedule == "cosine":
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
||||
)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == "sqrt":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
if use_fp16:
|
||||
return embedding.half()
|
||||
else:
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels, precision=16):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
if precision == 16:
|
||||
return GroupNorm16(16, channels)
|
||||
else:
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
class GroupNorm16(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.half()).type(x.dtype)
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
|
@ -1,92 +0,0 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
|
@ -1,76 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
|
||||
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
||||
super().__init__()
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
|
||||
else torch.tensor(-1,dtype=torch.int))
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
#remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.','')
|
||||
self.m_name2s_name.update({name:s_name})
|
||||
self.register_buffer(s_name,p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def forward(self,model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
with torch.no_grad():
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
|
@ -1,264 +0,0 @@
|
|||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
|
||||
import kornia
|
||||
from transformers.models.clip.modeling_clip import CLIPTextTransformer
|
||||
|
||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class'):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
|
||||
def forward(self, batch, key=None):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
|
||||
class TransformerEmbedder(AbstractEncoder):
|
||||
"""Some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer))
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = tokens.to(self.device) # meh
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
return z
|
||||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
class BERTTokenizer(AbstractEncoder):
|
||||
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||
super().__init__()
|
||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
self.device = device
|
||||
self.vq_interface = vq_interface
|
||||
self.max_length = max_length
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
return tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, text):
|
||||
tokens = self(text)
|
||||
if not self.vq_interface:
|
||||
return tokens
|
||||
return None, None, [None, None, tokens]
|
||||
|
||||
def decode(self, text):
|
||||
return text
|
||||
|
||||
|
||||
class BERTEmbedder(AbstractEncoder):
|
||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
|
||||
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
|
||||
super().__init__()
|
||||
self.use_tknz_fn = use_tokenizer
|
||||
if self.use_tknz_fn:
|
||||
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout)
|
||||
|
||||
def forward(self, text):
|
||||
if self.use_tknz_fn:
|
||||
tokens = self.tknz_fn(text)#.to(self.device)
|
||||
else:
|
||||
tokens = text
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
# output of length 77
|
||||
return self(text)
|
||||
|
||||
|
||||
class SpatialRescaler(nn.Module):
|
||||
def __init__(self,
|
||||
n_stages=1,
|
||||
method='bilinear',
|
||||
multiplier=0.5,
|
||||
in_channels=3,
|
||||
out_channels=None,
|
||||
bias=False):
|
||||
super().__init__()
|
||||
self.n_stages = n_stages
|
||||
assert self.n_stages >= 0
|
||||
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
|
||||
self.multiplier = multiplier
|
||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
||||
self.remap_output = out_channels is not None
|
||||
if self.remap_output:
|
||||
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
|
||||
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
|
||||
|
||||
def forward(self,x):
|
||||
for stage in range(self.n_stages):
|
||||
x = self.interpolator(x, scale_factor=self.multiplier)
|
||||
|
||||
|
||||
if self.remap_output:
|
||||
x = self.channel_mapper(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
class CLIPTextModelZero(CLIPTextModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = CLIPTextTransformerZero(config)
|
||||
|
||||
class CLIPTextTransformerZero(CLIPTextTransformer):
|
||||
def _build_causal_attention_mask(self, bsz, seq_len):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask.half()
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
|
||||
if use_fp16:
|
||||
self.transformer = CLIPTextModelZero.from_pretrained(version)
|
||||
else:
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
|
||||
# print(self.transformer.modules())
|
||||
# print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype))
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
# tokens = batch_encoding["input_ids"].to(self.device)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
# print("token type: {}".format(tokens.dtype))
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPTextEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the CLIP transformer encoder for text.
|
||||
"""
|
||||
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(version, jit=False, device="cpu")
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.n_repeat = n_repeat
|
||||
self.normalize = normalize
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = clip.tokenize(text).to(self.device)
|
||||
z = self.model.encode_text(tokens)
|
||||
if self.normalize:
|
||||
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
z = self(text)
|
||||
if z.ndim==2:
|
||||
z = z[:, None, :]
|
||||
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
|
||||
return z
|
||||
|
||||
|
||||
class FrozenClipImageEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
jit=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
antialias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
# x is assumed to be in range [-1,1]
|
||||
return self.model.encode_image(self.preprocess(x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ldm.util import count_params
|
||||
model = FrozenCLIPEmbedder()
|
||||
count_params(model, verbose=True)
|
|
@ -1,50 +0,0 @@
|
|||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
|
||||
"""
|
||||
|
||||
import torch
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
|
||||
except ImportError:
|
||||
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
|
||||
|
||||
|
||||
|
||||
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
|
||||
"""
|
||||
Arguments:
|
||||
qkv: (batch*seq, 3, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
"""
|
||||
max_s = seq_len
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
out = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_s, 0.0,
|
||||
softmax_scale=sm_scale, causal=False
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch*seq, nheads, headdim)
|
||||
kv: (batch*seq, 2, nheads, headdim)
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
"""
|
||||
cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
|
||||
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
|
||||
return out
|
|
@ -1,2 +0,0 @@
|
|||
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
|
||||
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
|
|
@ -1,730 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
from scipy import ndimage
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
import albumentations
|
||||
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
'''
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
'''
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
""" generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
'''
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
'''
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
||||
""""
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
'''
|
||||
python code from:
|
||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
||||
'''
|
||||
if filter_type == 'gaussian':
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == 'laplacian':
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
'''
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
'''
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
''' blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
''' bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
'''
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
''' blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype('float32')
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
||||
else:
|
||||
k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8: # up
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7: # down
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(30, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
||||
|
||||
hq = img.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
img = util.imresize_np(img, 1 / 2, True)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = img.shape[1], img.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
elif i == 6:
|
||||
# add processed camera sensor noise
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
hq = image.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
example = {"image":image}
|
||||
return example
|
||||
|
||||
|
||||
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
|
||||
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
|
||||
"""
|
||||
This is an extended degradation model by combining
|
||||
the degradation models of BSRGAN and Real-ESRGAN
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
use_shuffle: the degradation shuffle
|
||||
use_sharp: sharpening the img
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
||||
|
||||
if use_sharp:
|
||||
img = add_sharpening(img)
|
||||
hq = img.copy()
|
||||
|
||||
if random.random() < shuffle_prob:
|
||||
shuffle_order = random.sample(range(13), 13)
|
||||
else:
|
||||
shuffle_order = list(range(13))
|
||||
# local shuffle for noise, JPEG is always the last one
|
||||
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
|
||||
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
|
||||
|
||||
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
elif i == 1:
|
||||
img = add_resize(img, sf=sf)
|
||||
elif i == 2:
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
elif i == 3:
|
||||
if random.random() < poisson_prob:
|
||||
img = add_Poisson_noise(img)
|
||||
elif i == 4:
|
||||
if random.random() < speckle_prob:
|
||||
img = add_speckle_noise(img)
|
||||
elif i == 5:
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
elif i == 6:
|
||||
img = add_JPEG_noise(img)
|
||||
elif i == 7:
|
||||
img = add_blur(img, sf=sf)
|
||||
elif i == 8:
|
||||
img = add_resize(img, sf=sf)
|
||||
elif i == 9:
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
elif i == 10:
|
||||
if random.random() < poisson_prob:
|
||||
img = add_Poisson_noise(img)
|
||||
elif i == 11:
|
||||
if random.random() < speckle_prob:
|
||||
img = add_speckle_noise(img)
|
||||
elif i == 12:
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
else:
|
||||
print('check the shuffle!')
|
||||
|
||||
# resize to desired size
|
||||
img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("hey")
|
||||
img = util.imread_uint('utils/test.png', 3)
|
||||
print(img)
|
||||
img = util.uint2single(img)
|
||||
print(img)
|
||||
img = img[:448, :448]
|
||||
h = img.shape[0] // 4
|
||||
print("resizing to", h)
|
||||
sf = 4
|
||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
||||
for i in range(20):
|
||||
print(i)
|
||||
img_lq = deg_fn(img)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + '.png')
|
||||
|
||||
|
|
@ -1,650 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
import random
|
||||
from scipy import ndimage
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
import albumentations
|
||||
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
'''
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
'''
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[:w - w % sf, :h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
""" generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
'''
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
'''
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
|
||||
""""
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)],
|
||||
[np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
'''
|
||||
python code from:
|
||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
||||
'''
|
||||
if filter_type == 'gaussian':
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == 'laplacian':
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
'''
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
'''
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
''' blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
''' bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
'''
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
''' blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
'''
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype('float32')
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
|
||||
wd2 = wd2/4
|
||||
wd = wd/4
|
||||
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
|
||||
else:
|
||||
k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8: # up
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7: # down
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(80, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f'img size ({h1}X{w1}) is too small!')
|
||||
|
||||
hq = img.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
img = util.imresize_np(img, 1 / 2, True)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = img.shape[1], img.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
elif i == 6:
|
||||
# add processed camera sensor noise
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
hq = image.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
|
||||
|
||||
for i in shuffle_order:
|
||||
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
# elif i == 1:
|
||||
# image = add_blur(image, sf=sf)
|
||||
|
||||
if i == 0:
|
||||
pass
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.8:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]))
|
||||
else:
|
||||
k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
#
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
example = {"image": image}
|
||||
return example
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("hey")
|
||||
img = util.imread_uint('utils/test.png', 3)
|
||||
img = img[:448, :448]
|
||||
h = img.shape[0] // 4
|
||||
print("resizing to", h)
|
||||
sf = 4
|
||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
||||
for i in range(20):
|
||||
print(i)
|
||||
img_hq = img
|
||||
img_lq = deg_fn(img)["image"]
|
||||
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + '.png')
|
Binary file not shown.
Before Width: | Height: | Size: 431 KiB |
|
@ -1,916 +0,0 @@
|
|||
import os
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
from torchvision.utils import make_grid
|
||||
from datetime import datetime
|
||||
#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# Kai Zhang (github: https://github.com/cszn)
|
||||
# 03/Mar/2019
|
||||
# --------------------------------------------
|
||||
# https://github.com/twhui/SRGAN-pyTorch
|
||||
# https://github.com/xinntao/BasicSR
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return datetime.now().strftime('%y%m%d-%H%M%S')
|
||||
|
||||
|
||||
def imshow(x, title=None, cbar=False, figsize=None):
|
||||
plt.figure(figsize=figsize)
|
||||
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
|
||||
if title:
|
||||
plt.title(title)
|
||||
if cbar:
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
def surf(Z, cmap='rainbow', figsize=None):
|
||||
plt.figure(figsize=figsize)
|
||||
ax3 = plt.axes(projection='3d')
|
||||
|
||||
w, h = Z.shape[:2]
|
||||
xx = np.arange(0,w,1)
|
||||
yy = np.arange(0,h,1)
|
||||
X, Y = np.meshgrid(xx, yy)
|
||||
ax3.plot_surface(X,Y,Z,cmap=cmap)
|
||||
#ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
|
||||
plt.show()
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# get image pathes
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def get_image_paths(dataroot):
|
||||
paths = None # return None if dataroot is None
|
||||
if dataroot is not None:
|
||||
paths = sorted(_get_paths_from_images(dataroot))
|
||||
return paths
|
||||
|
||||
|
||||
def _get_paths_from_images(path):
|
||||
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
||||
images = []
|
||||
for dirpath, _, fnames in sorted(os.walk(path)):
|
||||
for fname in sorted(fnames):
|
||||
if is_image_file(fname):
|
||||
img_path = os.path.join(dirpath, fname)
|
||||
images.append(img_path)
|
||||
assert images, '{:s} has no valid image file'.format(path)
|
||||
return images
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# split large images into small images
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
|
||||
w, h = img.shape[:2]
|
||||
patches = []
|
||||
if w > p_max and h > p_max:
|
||||
w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
|
||||
h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
|
||||
w1.append(w-p_size)
|
||||
h1.append(h-p_size)
|
||||
# print(w1)
|
||||
# print(h1)
|
||||
for i in w1:
|
||||
for j in h1:
|
||||
patches.append(img[i:i+p_size, j:j+p_size,:])
|
||||
else:
|
||||
patches.append(img)
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
def imssave(imgs, img_path):
|
||||
"""
|
||||
imgs: list, N images of size WxHxC
|
||||
"""
|
||||
img_name, ext = os.path.splitext(os.path.basename(img_path))
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
|
||||
cv2.imwrite(new_path, img)
|
||||
|
||||
|
||||
def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
|
||||
"""
|
||||
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
|
||||
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
|
||||
will be splitted.
|
||||
Args:
|
||||
original_dataroot:
|
||||
taget_dataroot:
|
||||
p_size: size of small images
|
||||
p_overlap: patch size in training is a good choice
|
||||
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
|
||||
"""
|
||||
paths = get_image_paths(original_dataroot)
|
||||
for img_path in paths:
|
||||
# img_name, ext = os.path.splitext(os.path.basename(img_path))
|
||||
img = imread_uint(img_path, n_channels=n_channels)
|
||||
patches = patches_from_image(img, p_size, p_overlap, p_max)
|
||||
imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
|
||||
#if original_dataroot == taget_dataroot:
|
||||
#del img_path
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# makedir
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, str):
|
||||
mkdir(paths)
|
||||
else:
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
|
||||
|
||||
def mkdir_and_rename(path):
|
||||
if os.path.exists(path):
|
||||
new_name = path + '_archived_' + get_timestamp()
|
||||
print('Path already exists. Rename it to [{:s}]'.format(new_name))
|
||||
os.rename(path, new_name)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# read image from path
|
||||
# opencv is fast, but read BGR numpy image
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get uint8 image of size HxWxn_channles (RGB)
|
||||
# --------------------------------------------
|
||||
def imread_uint(path, n_channels=3):
|
||||
# input: path
|
||||
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
||||
if n_channels == 1:
|
||||
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
|
||||
img = np.expand_dims(img, axis=2) # HxWx1
|
||||
elif n_channels == 3:
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
|
||||
if img.ndim == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
|
||||
else:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
|
||||
return img
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# matlab's imwrite
|
||||
# --------------------------------------------
|
||||
def imsave(img, img_path):
|
||||
img = np.squeeze(img)
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
cv2.imwrite(img_path, img)
|
||||
|
||||
def imwrite(img, img_path):
|
||||
img = np.squeeze(img)
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
cv2.imwrite(img_path, img)
|
||||
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get single image of size HxWxn_channles (BGR)
|
||||
# --------------------------------------------
|
||||
def read_img(path):
|
||||
# read image by cv2
|
||||
# return: Numpy float32, HWC, BGR, [0,1]
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
|
||||
img = img.astype(np.float32) / 255.
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
# some images have 4 channels
|
||||
if img.shape[2] > 3:
|
||||
img = img[:, :, :3]
|
||||
return img
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# image format conversion
|
||||
# --------------------------------------------
|
||||
# numpy(single) <---> numpy(unit)
|
||||
# numpy(single) <---> tensor
|
||||
# numpy(unit) <---> tensor
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) [0, 1] <---> numpy(unit)
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
def uint2single(img):
|
||||
|
||||
return np.float32(img/255.)
|
||||
|
||||
|
||||
def single2uint(img):
|
||||
|
||||
return np.uint8((img.clip(0, 1)*255.).round())
|
||||
|
||||
|
||||
def uint162single(img):
|
||||
|
||||
return np.float32(img/65535.)
|
||||
|
||||
|
||||
def single2uint16(img):
|
||||
|
||||
return np.uint16((img.clip(0, 1)*65535.).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(unit) (HxWxC or HxW) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert uint to 4-dimensional torch tensor
|
||||
def uint2tensor4(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
|
||||
|
||||
|
||||
# convert uint to 3-dimensional torch tensor
|
||||
def uint2tensor3(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
|
||||
|
||||
|
||||
# convert 2/3/4-dimensional torch tensor to uint
|
||||
def tensor2uint(img):
|
||||
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
return np.uint8((img*255.0).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) (HxWxC) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert single (HxWxC) to 3-dimensional torch tensor
|
||||
def single2tensor3(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
|
||||
|
||||
|
||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
||||
def single2tensor4(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
||||
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single(img):
|
||||
img = img.data.squeeze().float().cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
return img
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single3(img):
|
||||
img = img.data.squeeze().float().cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
elif img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return img
|
||||
|
||||
|
||||
def single2tensor5(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
||||
|
||||
|
||||
def single32tensor5(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
|
||||
def single42tensor4(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
|
||||
|
||||
|
||||
# from skimage.io import imread, imsave
|
||||
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
||||
'''
|
||||
Converts a torch Tensor into an image Numpy array of BGR channel order
|
||||
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
||||
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
||||
'''
|
||||
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
||||
n_dim = tensor.dim()
|
||||
if n_dim == 4:
|
||||
n_img = len(tensor)
|
||||
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||
elif n_dim == 3:
|
||||
img_np = tensor.numpy()
|
||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||
elif n_dim == 2:
|
||||
img_np = tensor.numpy()
|
||||
else:
|
||||
raise TypeError(
|
||||
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
|
||||
if out_type == np.uint8:
|
||||
img_np = (img_np * 255.0).round()
|
||||
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
||||
return img_np.astype(out_type)
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# Augmentation, flipe and/or rotate
|
||||
# --------------------------------------------
|
||||
# The following two are enough.
|
||||
# (1) augmet_img: numpy image of WxHxC or WxH
|
||||
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def augment_img(img, mode=0):
|
||||
'''Kai Zhang (github: https://github.com/cszn)
|
||||
'''
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return np.flipud(np.rot90(img))
|
||||
elif mode == 2:
|
||||
return np.flipud(img)
|
||||
elif mode == 3:
|
||||
return np.rot90(img, k=3)
|
||||
elif mode == 4:
|
||||
return np.flipud(np.rot90(img, k=2))
|
||||
elif mode == 5:
|
||||
return np.rot90(img)
|
||||
elif mode == 6:
|
||||
return np.rot90(img, k=2)
|
||||
elif mode == 7:
|
||||
return np.flipud(np.rot90(img, k=3))
|
||||
|
||||
|
||||
def augment_img_tensor4(img, mode=0):
|
||||
'''Kai Zhang (github: https://github.com/cszn)
|
||||
'''
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return img.rot90(1, [2, 3]).flip([2])
|
||||
elif mode == 2:
|
||||
return img.flip([2])
|
||||
elif mode == 3:
|
||||
return img.rot90(3, [2, 3])
|
||||
elif mode == 4:
|
||||
return img.rot90(2, [2, 3]).flip([2])
|
||||
elif mode == 5:
|
||||
return img.rot90(1, [2, 3])
|
||||
elif mode == 6:
|
||||
return img.rot90(2, [2, 3])
|
||||
elif mode == 7:
|
||||
return img.rot90(3, [2, 3]).flip([2])
|
||||
|
||||
|
||||
def augment_img_tensor(img, mode=0):
|
||||
'''Kai Zhang (github: https://github.com/cszn)
|
||||
'''
|
||||
img_size = img.size()
|
||||
img_np = img.data.cpu().numpy()
|
||||
if len(img_size) == 3:
|
||||
img_np = np.transpose(img_np, (1, 2, 0))
|
||||
elif len(img_size) == 4:
|
||||
img_np = np.transpose(img_np, (2, 3, 1, 0))
|
||||
img_np = augment_img(img_np, mode=mode)
|
||||
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
|
||||
if len(img_size) == 3:
|
||||
img_tensor = img_tensor.permute(2, 0, 1)
|
||||
elif len(img_size) == 4:
|
||||
img_tensor = img_tensor.permute(3, 2, 0, 1)
|
||||
|
||||
return img_tensor.type_as(img)
|
||||
|
||||
|
||||
def augment_img_np3(img, mode=0):
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return img.transpose(1, 0, 2)
|
||||
elif mode == 2:
|
||||
return img[::-1, :, :]
|
||||
elif mode == 3:
|
||||
img = img[::-1, :, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
elif mode == 4:
|
||||
return img[:, ::-1, :]
|
||||
elif mode == 5:
|
||||
img = img[:, ::-1, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
elif mode == 6:
|
||||
img = img[:, ::-1, :]
|
||||
img = img[::-1, :, :]
|
||||
return img
|
||||
elif mode == 7:
|
||||
img = img[:, ::-1, :]
|
||||
img = img[::-1, :, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
|
||||
def augment_imgs(img_list, hflip=True, rot=True):
|
||||
# horizontal flip OR rotate
|
||||
hflip = hflip and random.random() < 0.5
|
||||
vflip = rot and random.random() < 0.5
|
||||
rot90 = rot and random.random() < 0.5
|
||||
|
||||
def _augment(img):
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
img = img[::-1, :, :]
|
||||
if rot90:
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
return [_augment(img) for img in img_list]
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# modcrop and shave
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def modcrop(img_in, scale):
|
||||
# img_in: Numpy, HWC or HW
|
||||
img = np.copy(img_in)
|
||||
if img.ndim == 2:
|
||||
H, W = img.shape
|
||||
H_r, W_r = H % scale, W % scale
|
||||
img = img[:H - H_r, :W - W_r]
|
||||
elif img.ndim == 3:
|
||||
H, W, C = img.shape
|
||||
H_r, W_r = H % scale, W % scale
|
||||
img = img[:H - H_r, :W - W_r, :]
|
||||
else:
|
||||
raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
|
||||
return img
|
||||
|
||||
|
||||
def shave(img_in, border=0):
|
||||
# img_in: Numpy, HWC or HW
|
||||
img = np.copy(img_in)
|
||||
h, w = img.shape[:2]
|
||||
img = img[border:h-border, border:w-border]
|
||||
return img
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# image processing process on numpy image
|
||||
# channel_convert(in_c, tar_type, img_list):
|
||||
# rgb2ycbcr(img, only_y=True):
|
||||
# bgr2ycbcr(img, only_y=True):
|
||||
# ycbcr2rgb(img):
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
def rgb2ycbcr(img, only_y=True):
|
||||
'''same as matlab rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
|
||||
[24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
|
||||
'''same as matlab ycbcr2rgb
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
|
||||
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, only_y=True):
|
||||
'''bgr version of rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
'''
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
|
||||
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def channel_convert(in_c, tar_type, img_list):
|
||||
# conversion among BGR, gray and y
|
||||
if in_c == 3 and tar_type == 'gray': # BGR to gray
|
||||
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in gray_list]
|
||||
elif in_c == 3 and tar_type == 'y': # BGR to y
|
||||
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in y_list]
|
||||
elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
|
||||
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
||||
else:
|
||||
return img_list
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# metric, PSNR and SSIM
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# PSNR
|
||||
# --------------------------------------------
|
||||
def calculate_psnr(img1, img2, border=0):
|
||||
# img1 and img2 have range [0, 255]
|
||||
#img1 = img1.squeeze()
|
||||
#img2 = img2.squeeze()
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border:h-border, border:w-border]
|
||||
img2 = img2[border:h-border, border:w-border]
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
mse = np.mean((img1 - img2)**2)
|
||||
if mse == 0:
|
||||
return float('inf')
|
||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# SSIM
|
||||
# --------------------------------------------
|
||||
def calculate_ssim(img1, img2, border=0):
|
||||
'''calculate SSIM
|
||||
the same outputs as MATLAB's
|
||||
img1, img2: [0, 255]
|
||||
'''
|
||||
#img1 = img1.squeeze()
|
||||
#img2 = img2.squeeze()
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError('Input images must have the same dimensions.')
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border:h-border, border:w-border]
|
||||
img2 = img2[border:h-border, border:w-border]
|
||||
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError('Wrong input image dimensions.')
|
||||
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = (0.01 * 255)**2
|
||||
C2 = (0.03 * 255)**2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1**2
|
||||
mu2_sq = mu2**2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
|
||||
(sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
'''
|
||||
# --------------------------------------------
|
||||
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
||||
# --------------------------------------------
|
||||
'''
|
||||
|
||||
|
||||
# matlab 'imresize' function, now only support 'bicubic'
|
||||
def cubic(x):
|
||||
absx = torch.abs(x)
|
||||
absx2 = absx**2
|
||||
absx3 = absx**3
|
||||
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
|
||||
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
||||
if (scale < 1) and (antialiasing):
|
||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||
kernel_width = kernel_width / scale
|
||||
|
||||
# Output-space coordinates
|
||||
x = torch.linspace(1, out_length, out_length)
|
||||
|
||||
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
||||
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
||||
# space maps to 1.5 in input space.
|
||||
u = x / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
# What is the left-most pixel that can be involved in the computation?
|
||||
left = torch.floor(u - kernel_width / 2)
|
||||
|
||||
# What is the maximum number of pixels that can be involved in the
|
||||
# computation? Note: it's OK to use an extra pixel here; if the
|
||||
# corresponding weights are all zero, it will be eliminated at the end
|
||||
# of this function.
|
||||
P = math.ceil(kernel_width) + 2
|
||||
|
||||
# The indices of the input pixels involved in computing the k-th output
|
||||
# pixel are in row k of the indices matrix.
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
|
||||
1, P).expand(out_length, P)
|
||||
|
||||
# The weights used to compute the k-th output pixel are in row k of the
|
||||
# weights matrix.
|
||||
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
||||
# apply cubic kernel
|
||||
if (scale < 1) and (antialiasing):
|
||||
weights = scale * cubic(distance_to_center * scale)
|
||||
else:
|
||||
weights = cubic(distance_to_center)
|
||||
# Normalize the weights matrix so that each row sums to 1.
|
||||
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
||||
weights = weights / weights_sum.expand(out_length, P)
|
||||
|
||||
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
||||
weights_zero_tmp = torch.sum((weights == 0), 0)
|
||||
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 1, P - 2)
|
||||
weights = weights.narrow(1, 1, P - 2)
|
||||
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 0, P - 2)
|
||||
weights = weights.narrow(1, 0, P - 2)
|
||||
weights = weights.contiguous()
|
||||
indices = indices.contiguous()
|
||||
sym_len_s = -indices.min() + 1
|
||||
sym_len_e = indices.max() - in_length
|
||||
indices = indices + sym_len_s - 1
|
||||
return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for tensor image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: pytorch tensor, CHW or HW [0,1]
|
||||
# output: CHW or HW [0,1] w/o round
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(0)
|
||||
in_C, in_H, in_W = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||
kernel_width = 4
|
||||
kernel = 'cubic'
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
||||
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:, :sym_len_Hs, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[:, -sym_len_He:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
||||
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :, :sym_len_Ws]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, :, -sym_len_We:]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
return out_2
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for numpy image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize_np(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: Numpy, HWC or HW [0,1]
|
||||
# output: HWC or HW [0,1] w/o round
|
||||
img = torch.from_numpy(img)
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(2)
|
||||
|
||||
in_H, in_W, in_C = img.size()
|
||||
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
|
||||
kernel_width = 4
|
||||
kernel = 'cubic'
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
||||
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:sym_len_Hs, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[-sym_len_He:, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
||||
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :sym_len_Ws, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, -sym_len_We:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
|
||||
return out_2.numpy()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('---')
|
||||
# img = imread_uint('test.bmp', 3)
|
||||
# img = uint2single(img)
|
||||
# img_bicubic = imresize_np(img, 1/4)
|
|
@ -1 +0,0 @@
|
|||
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
|
|
@ -1,111 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
|
||||
|
||||
|
||||
class LPIPSWithDiscriminator(nn.Module):
|
||||
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
|
||||
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
||||
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
||||
disc_loss="hinge"):
|
||||
|
||||
super().__init__()
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
self.kl_weight = kl_weight
|
||||
self.pixel_weight = pixelloss_weight
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
# output log variance
|
||||
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
||||
|
||||
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
||||
n_layers=disc_num_layers,
|
||||
use_actnorm=use_actnorm
|
||||
).apply(weights_init)
|
||||
self.discriminator_iter_start = disc_start
|
||||
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
||||
self.disc_factor = disc_factor
|
||||
self.discriminator_weight = disc_weight
|
||||
self.disc_conditional = disc_conditional
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
|
||||
global_step, last_layer=None, cond=None, split="train",
|
||||
weights=None):
|
||||
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0:
|
||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
|
||||
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
||||
weighted_nll_loss = nll_loss
|
||||
if weights is not None:
|
||||
weighted_nll_loss = weights*nll_loss
|
||||
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
||||
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
kl_loss = posteriors.kl()
|
||||
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||
|
||||
# now the GAN part
|
||||
if optimizer_idx == 0:
|
||||
# generator update
|
||||
if cond is None:
|
||||
assert not self.disc_conditional
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
else:
|
||||
assert self.disc_conditional
|
||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
|
||||
if self.disc_factor > 0.0:
|
||||
try:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
||||
except RuntimeError:
|
||||
assert not self.training
|
||||
d_weight = torch.tensor(0.0)
|
||||
else:
|
||||
d_weight = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
|
||||
|
||||
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
|
||||
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
|
||||
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
||||
"{}/d_weight".format(split): d_weight.detach(),
|
||||
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
||||
"{}/g_loss".format(split): g_loss.detach().mean(),
|
||||
}
|
||||
return loss, log
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# second pass for discriminator update
|
||||
if cond is None:
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
else:
|
||||
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
||||
|
||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
||||
|
||||
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
||||
"{}/logits_real".format(split): logits_real.detach().mean(),
|
||||
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
||||
}
|
||||
return d_loss, log
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
|
||||
from taming.modules.losses.lpips import LPIPS
|
||||
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
|
||||
|
||||
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
|
||||
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
|
||||
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
|
||||
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
|
||||
loss_real = (weights * loss_real).sum() / weights.sum()
|
||||
loss_fake = (weights * loss_fake).sum() / weights.sum()
|
||||
d_loss = 0.5 * (loss_real + loss_fake)
|
||||
return d_loss
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
def measure_perplexity(predicted_indices, n_embed):
|
||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
|
||||
avg_probs = encodings.mean(0)
|
||||
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
||||
cluster_use = torch.sum(avg_probs > 0)
|
||||
return perplexity, cluster_use
|
||||
|
||||
def l1(x, y):
|
||||
return torch.abs(x-y)
|
||||
|
||||
|
||||
def l2(x, y):
|
||||
return torch.pow((x-y), 2)
|
||||
|
||||
|
||||
class VQLPIPSWithDiscriminator(nn.Module):
|
||||
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
||||
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
||||
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
||||
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
|
||||
pixel_loss="l1"):
|
||||
super().__init__()
|
||||
assert disc_loss in ["hinge", "vanilla"]
|
||||
assert perceptual_loss in ["lpips", "clips", "dists"]
|
||||
assert pixel_loss in ["l1", "l2"]
|
||||
self.codebook_weight = codebook_weight
|
||||
self.pixel_weight = pixelloss_weight
|
||||
if perceptual_loss == "lpips":
|
||||
print(f"{self.__class__.__name__}: Running with LPIPS.")
|
||||
self.perceptual_loss = LPIPS().eval()
|
||||
else:
|
||||
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
|
||||
self.perceptual_weight = perceptual_weight
|
||||
|
||||
if pixel_loss == "l1":
|
||||
self.pixel_loss = l1
|
||||
else:
|
||||
self.pixel_loss = l2
|
||||
|
||||
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
||||
n_layers=disc_num_layers,
|
||||
use_actnorm=use_actnorm,
|
||||
ndf=disc_ndf
|
||||
).apply(weights_init)
|
||||
self.discriminator_iter_start = disc_start
|
||||
if disc_loss == "hinge":
|
||||
self.disc_loss = hinge_d_loss
|
||||
elif disc_loss == "vanilla":
|
||||
self.disc_loss = vanilla_d_loss
|
||||
else:
|
||||
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
||||
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
||||
self.disc_factor = disc_factor
|
||||
self.discriminator_weight = disc_weight
|
||||
self.disc_conditional = disc_conditional
|
||||
self.n_classes = n_classes
|
||||
|
||||
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
||||
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
||||
d_weight = d_weight * self.discriminator_weight
|
||||
return d_weight
|
||||
|
||||
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
||||
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
|
||||
if not exists(codebook_loss):
|
||||
codebook_loss = torch.tensor([0.]).to(inputs.device)
|
||||
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
||||
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
|
||||
if self.perceptual_weight > 0:
|
||||
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
else:
|
||||
p_loss = torch.tensor([0.0])
|
||||
|
||||
nll_loss = rec_loss
|
||||
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
||||
nll_loss = torch.mean(nll_loss)
|
||||
|
||||
# now the GAN part
|
||||
if optimizer_idx == 0:
|
||||
# generator update
|
||||
if cond is None:
|
||||
assert not self.disc_conditional
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
else:
|
||||
assert self.disc_conditional
|
||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
|
||||
try:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
||||
except RuntimeError:
|
||||
assert not self.training
|
||||
d_weight = torch.tensor(0.0)
|
||||
|
||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
||||
|
||||
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
||||
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
||||
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
||||
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
||||
"{}/p_loss".format(split): p_loss.detach().mean(),
|
||||
"{}/d_weight".format(split): d_weight.detach(),
|
||||
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
||||
"{}/g_loss".format(split): g_loss.detach().mean(),
|
||||
}
|
||||
if predicted_indices is not None:
|
||||
assert self.n_classes is not None
|
||||
with torch.no_grad():
|
||||
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
|
||||
log[f"{split}/perplexity"] = perplexity
|
||||
log[f"{split}/cluster_usage"] = cluster_usage
|
||||
return loss, log
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# second pass for discriminator update
|
||||
if cond is None:
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
||||
else:
|
||||
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
||||
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
||||
|
||||
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
||||
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
||||
|
||||
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
||||
"{}/logits_real".format(split): logits_real.detach().mean(),
|
||||
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
||||
}
|
||||
return d_loss, log
|
|
@ -1,641 +0,0 @@
|
|||
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
from collections import namedtuple
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates', [
|
||||
'pre_softmax_attn',
|
||||
'post_softmax_attn'
|
||||
])
|
||||
|
||||
LayerIntermediates = namedtuple('Intermediates', [
|
||||
'hiddens',
|
||||
'attn_intermediates'
|
||||
])
|
||||
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(max_seq_len, dim)
|
||||
self.init_()
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.emb.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
n = torch.arange(x.shape[1], device=x.device)
|
||||
return self.emb(n)[None, :, :]
|
||||
|
||||
|
||||
class FixedPositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x, seq_dim=1, offset=0):
|
||||
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
||||
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
||||
return emb[None, :, :]
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def always(val):
|
||||
def inner(*args, **kwargs):
|
||||
return val
|
||||
return inner
|
||||
|
||||
|
||||
def not_equals(val):
|
||||
def inner(x):
|
||||
return x != val
|
||||
return inner
|
||||
|
||||
|
||||
def equals(val):
|
||||
def inner(x):
|
||||
return x == val
|
||||
return inner
|
||||
|
||||
|
||||
def max_neg_value(tensor):
|
||||
return -torch.finfo(tensor.dtype).max
|
||||
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
|
||||
# classes
|
||||
class Scale(nn.Module):
|
||||
def __init__(self, value, fn):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.value, *rest)
|
||||
|
||||
|
||||
class Rezero(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.g = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.g, *rest)
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-8):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def forward(self, x, residual):
|
||||
return x + residual
|
||||
|
||||
|
||||
class GRUGating(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gru = nn.GRUCell(dim, dim)
|
||||
|
||||
def forward(self, x, residual):
|
||||
gated_output = self.gru(
|
||||
rearrange(x, 'b n d -> (b n) d'),
|
||||
rearrange(residual, 'b n d -> (b n) d')
|
||||
)
|
||||
|
||||
return gated_output.reshape_as(x)
|
||||
|
||||
|
||||
# feedforward
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# attention.
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=DEFAULT_DIM_HEAD,
|
||||
heads=8,
|
||||
causal=False,
|
||||
mask=None,
|
||||
talking_heads=False,
|
||||
sparse_topk=None,
|
||||
use_entmax15=False,
|
||||
num_mem_kv=0,
|
||||
dropout=0.,
|
||||
on_attn=False
|
||||
):
|
||||
super().__init__()
|
||||
if use_entmax15:
|
||||
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.causal = causal
|
||||
self.mask = mask
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# talking heads
|
||||
self.talking_heads = talking_heads
|
||||
if talking_heads:
|
||||
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
# explicit topk sparse attention
|
||||
self.sparse_topk = sparse_topk
|
||||
|
||||
# entmax
|
||||
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
||||
self.attn_fn = F.softmax
|
||||
|
||||
# add memory key / values
|
||||
self.num_mem_kv = num_mem_kv
|
||||
if num_mem_kv > 0:
|
||||
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
|
||||
# attention on attention
|
||||
self.attn_on_attn = on_attn
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
rel_pos=None,
|
||||
sinusoidal_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None
|
||||
):
|
||||
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
|
||||
kv_input = default(context, x)
|
||||
|
||||
q_input = x
|
||||
k_input = kv_input
|
||||
v_input = kv_input
|
||||
|
||||
if exists(mem):
|
||||
k_input = torch.cat((mem, k_input), dim=-2)
|
||||
v_input = torch.cat((mem, v_input), dim=-2)
|
||||
|
||||
if exists(sinusoidal_emb):
|
||||
# in shortformer, the query would start at a position offset depending on the past cached memory
|
||||
offset = k_input.shape[-2] - q_input.shape[-2]
|
||||
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
|
||||
k_input = k_input + sinusoidal_emb(k_input)
|
||||
|
||||
q = self.to_q(q_input)
|
||||
k = self.to_k(k_input)
|
||||
v = self.to_v(v_input)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
||||
|
||||
input_mask = None
|
||||
if any(map(exists, (mask, context_mask))):
|
||||
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
|
||||
k_mask = q_mask if not exists(context) else context_mask
|
||||
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
|
||||
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
||||
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
||||
input_mask = q_mask * k_mask
|
||||
|
||||
if self.num_mem_kv > 0:
|
||||
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
|
||||
k = torch.cat((mem_k, k), dim=-2)
|
||||
v = torch.cat((mem_v, v), dim=-2)
|
||||
if exists(input_mask):
|
||||
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
mask_value = max_neg_value(dots)
|
||||
|
||||
if exists(prev_attn):
|
||||
dots = dots + prev_attn
|
||||
|
||||
pre_softmax_attn = dots
|
||||
|
||||
if talking_heads:
|
||||
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
|
||||
|
||||
if exists(rel_pos):
|
||||
dots = rel_pos(dots)
|
||||
|
||||
if exists(input_mask):
|
||||
dots.masked_fill_(~input_mask, mask_value)
|
||||
del input_mask
|
||||
|
||||
if self.causal:
|
||||
i, j = dots.shape[-2:]
|
||||
r = torch.arange(i, device=device)
|
||||
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
|
||||
mask = F.pad(mask, (j - i, 0), value=False)
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
||||
top, _ = dots.topk(self.sparse_topk, dim=-1)
|
||||
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
|
||||
mask = dots < vk
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = self.attn_fn(dots, dim=-1)
|
||||
post_softmax_attn = attn
|
||||
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if talking_heads:
|
||||
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
intermediates = Intermediates(
|
||||
pre_softmax_attn=pre_softmax_attn,
|
||||
post_softmax_attn=post_softmax_attn
|
||||
)
|
||||
|
||||
return self.to_out(out), intermediates
|
||||
|
||||
|
||||
class AttentionLayers(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
pre_norm=True,
|
||||
gate_residual=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
||||
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
|
||||
|
||||
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.has_pos_emb = position_infused_attn
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
||||
self.rotary_pos_emb = always(None)
|
||||
|
||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
||||
self.rel_pos = None
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
self.residual_attn = residual_attn
|
||||
self.cross_residual_attn = cross_residual_attn
|
||||
|
||||
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
||||
norm_class = RMSNorm if use_rmsnorm else norm_class
|
||||
norm_fn = partial(norm_class, dim)
|
||||
|
||||
norm_fn = nn.Identity if use_rezero else norm_fn
|
||||
branch_fn = Rezero if use_rezero else None
|
||||
|
||||
if cross_attend and not only_cross:
|
||||
default_block = ('a', 'c', 'f')
|
||||
elif cross_attend and only_cross:
|
||||
default_block = ('c', 'f')
|
||||
else:
|
||||
default_block = ('a', 'f')
|
||||
|
||||
if macaron:
|
||||
default_block = ('f',) + default_block
|
||||
|
||||
if exists(custom_layers):
|
||||
layer_types = custom_layers
|
||||
elif exists(par_ratio):
|
||||
par_depth = depth * len(default_block)
|
||||
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
||||
default_block = tuple(filter(not_equals('f'), default_block))
|
||||
par_attn = par_depth // par_ratio
|
||||
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
||||
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f',) * (par_width - len(default_block))
|
||||
par_head = par_block * par_attn
|
||||
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
||||
elif exists(sandwich_coef):
|
||||
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
||||
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
||||
else:
|
||||
layer_types = default_block * depth
|
||||
|
||||
self.layer_types = layer_types
|
||||
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
if layer_type == 'a':
|
||||
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
||||
elif layer_type == 'c':
|
||||
layer = Attention(dim, heads=heads, **attn_kwargs)
|
||||
elif layer_type == 'f':
|
||||
layer = FeedForward(dim, **ff_kwargs)
|
||||
layer = layer if not macaron else Scale(0.5, layer)
|
||||
else:
|
||||
raise Exception(f'invalid layer type {layer_type}')
|
||||
|
||||
if isinstance(layer, Attention) and exists(branch_fn):
|
||||
layer = branch_fn(layer)
|
||||
|
||||
if gate_residual:
|
||||
residual_fn = GRUGating(dim)
|
||||
else:
|
||||
residual_fn = Residual()
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
norm_fn(),
|
||||
layer,
|
||||
residual_fn
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False
|
||||
):
|
||||
hiddens = []
|
||||
intermediates = []
|
||||
prev_attn = None
|
||||
prev_cross_attn = None
|
||||
|
||||
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
||||
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
||||
is_last = ind == (len(self.layers) - 1)
|
||||
|
||||
if layer_type == 'a':
|
||||
hiddens.append(x)
|
||||
layer_mem = mems.pop(0)
|
||||
|
||||
residual = x
|
||||
|
||||
if self.pre_norm:
|
||||
x = norm(x)
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
|
||||
prev_attn=prev_attn, mem=layer_mem)
|
||||
elif layer_type == 'c':
|
||||
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
|
||||
elif layer_type == 'f':
|
||||
out = block(x)
|
||||
|
||||
x = residual_fn(out, residual)
|
||||
|
||||
if layer_type in ('a', 'c'):
|
||||
intermediates.append(inter)
|
||||
|
||||
if layer_type == 'a' and self.residual_attn:
|
||||
prev_attn = inter.pre_softmax_attn
|
||||
elif layer_type == 'c' and self.cross_residual_attn:
|
||||
prev_cross_attn = inter.pre_softmax_attn
|
||||
|
||||
if not self.pre_norm and not is_last:
|
||||
x = norm(x)
|
||||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens=hiddens,
|
||||
attn_intermediates=intermediates
|
||||
)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(AttentionLayers):
|
||||
def __init__(self, **kwargs):
|
||||
assert 'causal' not in kwargs, 'cannot set causality on encoder'
|
||||
super().__init__(causal=False, **kwargs)
|
||||
|
||||
|
||||
|
||||
class TransformerWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_tokens,
|
||||
max_seq_len,
|
||||
attn_layers,
|
||||
emb_dim=None,
|
||||
max_mem_len=0.,
|
||||
emb_dropout=0.,
|
||||
num_memory_tokens=None,
|
||||
tie_embedding=False,
|
||||
use_pos_emb=True
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
||||
|
||||
dim = attn_layers.dim
|
||||
emb_dim = default(emb_dim, dim)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_mem_len = max_mem_len
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
||||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
||||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.init_()
|
||||
|
||||
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
|
||||
# memory tokens (like [cls]) from Memory Transformers paper
|
||||
num_memory_tokens = default(num_memory_tokens, 0)
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||
|
||||
# let funnel encoder know number of memory tokens, if specified
|
||||
if hasattr(attn_layers, 'num_memory_tokens'):
|
||||
attn_layers.num_memory_tokens = num_memory_tokens
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.token_emb.weight, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embeddings=False,
|
||||
mask=None,
|
||||
return_mems=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
**kwargs
|
||||
):
|
||||
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
||||
x = self.token_emb(x)
|
||||
x += self.pos_emb(x)
|
||||
x = self.emb_dropout(x)
|
||||
|
||||
x = self.project_emb(x)
|
||||
|
||||
if num_mem > 0:
|
||||
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
|
||||
x = torch.cat((mem, x), dim=1)
|
||||
|
||||
# auto-handle masking after appending memory tokens
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (num_mem, 0), value=True)
|
||||
|
||||
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
||||
x = self.norm(x)
|
||||
|
||||
mem, x = x[:, :num_mem], x[:, num_mem:]
|
||||
|
||||
out = self.to_logits(x) if not return_embeddings else x
|
||||
|
||||
if return_mems:
|
||||
hiddens = intermediates.hiddens
|
||||
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
|
||||
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
|
||||
return out, new_mems
|
||||
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
return out, attn_maps
|
||||
|
||||
return out
|
||||
|
|
@ -1,203 +0,0 @@
|
|||
import importlib
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import abc
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
import multiprocessing as mp
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
|
||||
# create dummy dataset instance
|
||||
|
||||
# run prefetching
|
||||
if idx_to_fn:
|
||||
res = func(data, worker_id=idx)
|
||||
else:
|
||||
res = func(data)
|
||||
Q.put([idx, res])
|
||||
Q.put("Done")
|
||||
|
||||
|
||||
def parallel_data_prefetch(
|
||||
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
|
||||
):
|
||||
# if target_data_type not in ["ndarray", "list"]:
|
||||
# raise ValueError(
|
||||
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
|
||||
# )
|
||||
if isinstance(data, np.ndarray) and target_data_type == "list":
|
||||
raise ValueError("list expected but function got ndarray.")
|
||||
elif isinstance(data, abc.Iterable):
|
||||
if isinstance(data, dict):
|
||||
print(
|
||||
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||
)
|
||||
data = list(data.values())
|
||||
if target_data_type == "ndarray":
|
||||
data = np.asarray(data)
|
||||
else:
|
||||
data = list(data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
|
||||
)
|
||||
|
||||
if cpu_intensive:
|
||||
Q = mp.Queue(1000)
|
||||
proc = mp.Process
|
||||
else:
|
||||
Q = Queue(1000)
|
||||
proc = Thread
|
||||
# spawn processes
|
||||
if target_data_type == "ndarray":
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate(np.array_split(data, n_proc))
|
||||
]
|
||||
else:
|
||||
step = (
|
||||
int(len(data) / n_proc + 1)
|
||||
if len(data) % n_proc != 0
|
||||
else int(len(data) / n_proc)
|
||||
)
|
||||
arguments = [
|
||||
[func, Q, part, i, use_worker_id]
|
||||
for i, part in enumerate(
|
||||
[data[i: i + step] for i in range(0, len(data), step)]
|
||||
)
|
||||
]
|
||||
processes = []
|
||||
for i in range(n_proc):
|
||||
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
|
||||
processes += [p]
|
||||
|
||||
# start processes
|
||||
print(f"Start prefetching...")
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
gather_res = [[] for _ in range(n_proc)]
|
||||
try:
|
||||
for p in processes:
|
||||
p.start()
|
||||
|
||||
k = 0
|
||||
while k < n_proc:
|
||||
# get result
|
||||
res = Q.get()
|
||||
if res == "Done":
|
||||
k += 1
|
||||
else:
|
||||
gather_res[res[0]] = res[1]
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e)
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
|
||||
raise e
|
||||
finally:
|
||||
for p in processes:
|
||||
p.join()
|
||||
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||
|
||||
if target_data_type == 'ndarray':
|
||||
if not isinstance(gather_res[0], np.ndarray):
|
||||
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
||||
|
||||
# order outputs
|
||||
return np.concatenate(gather_res, axis=0)
|
||||
elif target_data_type == 'list':
|
||||
out = []
|
||||
for r in gather_res:
|
||||
out.extend(r)
|
||||
return out
|
||||
else:
|
||||
return gather_res
|
|
@ -1,830 +0,0 @@
|
|||
import argparse, os, sys, datetime, glob, importlib, csv
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import torchvision
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from packaging import version
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.data import random_split, DataLoader, Dataset, Subset
|
||||
from functools import partial
|
||||
from PIL import Image
|
||||
# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy
|
||||
# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from prefetch_generator import BackgroundGenerator
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from diffusers.models.unet_2d import UNet2DModel
|
||||
|
||||
from clip.model import Bottleneck
|
||||
from transformers.models.clip.modeling_clip import CLIPTextTransformer
|
||||
|
||||
from ldm.data.base import Txt2ImgIterableBaseDataset
|
||||
from ldm.util import instantiate_from_config
|
||||
import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
|
||||
from ldm.modules.x_transformer import *
|
||||
from ldm.modules.encoders.modules import *
|
||||
from taming.modules.diffusionmodules.model import ResnetBlock
|
||||
from taming.modules.transformer.mingpt import *
|
||||
from taming.modules.transformer.permuter import *
|
||||
|
||||
|
||||
from ldm.modules.ema import LitEma
|
||||
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
from ldm.models.autoencoder import *
|
||||
from ldm.models.diffusion.ddim import *
|
||||
from ldm.modules.diffusionmodules.openaimodel import *
|
||||
from ldm.modules.diffusionmodules.model import *
|
||||
from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
|
||||
from ldm.modules.attention import enable_flash_attention
|
||||
|
||||
class DataLoaderX(DataLoader):
|
||||
|
||||
def __iter__(self):
|
||||
return BackgroundGenerator(super().__iter__())
|
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--name",
|
||||
type=str,
|
||||
const=True,
|
||||
default="",
|
||||
nargs="?",
|
||||
help="postfix for logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resume",
|
||||
type=str,
|
||||
const=True,
|
||||
default="",
|
||||
nargs="?",
|
||||
help="resume from logdir or checkpoint in logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--base",
|
||||
nargs="*",
|
||||
metavar="base_config.yaml",
|
||||
help="paths to base configs. Loaded from left-to-right. "
|
||||
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
||||
default=list(),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--train",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="train",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-test",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="disable test",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--project",
|
||||
help="name of new or path to existing project"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="enable post-mortem debugging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
default=23,
|
||||
help="seed for seed_everything",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--postfix",
|
||||
type=str,
|
||||
default="",
|
||||
help="post-postfix for default name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--logdir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help="directory for logging dat shit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="whether to use fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash",
|
||||
type=str2bool,
|
||||
const=True,
|
||||
default=False,
|
||||
nargs="?",
|
||||
help="whether to use flash attention",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def nondefault_trainer_args(opt):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args([])
|
||||
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
||||
|
||||
|
||||
class WrappedDataset(Dataset):
|
||||
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.data = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
|
||||
def worker_init_fn(_):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
dataset = worker_info.dataset
|
||||
worker_id = worker_info.id
|
||||
|
||||
if isinstance(dataset, Txt2ImgIterableBaseDataset):
|
||||
split_size = dataset.num_records // worker_info.num_workers
|
||||
# reset num_records to the true number to retain reliable length information
|
||||
dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
|
||||
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
||||
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
||||
else:
|
||||
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
|
||||
class DataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
|
||||
wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
|
||||
shuffle_val_dataloader=False):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.dataset_configs = dict()
|
||||
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
||||
self.use_worker_init_fn = use_worker_init_fn
|
||||
if train is not None:
|
||||
self.dataset_configs["train"] = train
|
||||
self.train_dataloader = self._train_dataloader
|
||||
if validation is not None:
|
||||
self.dataset_configs["validation"] = validation
|
||||
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
|
||||
if test is not None:
|
||||
self.dataset_configs["test"] = test
|
||||
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
|
||||
if predict is not None:
|
||||
self.dataset_configs["predict"] = predict
|
||||
self.predict_dataloader = self._predict_dataloader
|
||||
self.wrap = wrap
|
||||
|
||||
def prepare_data(self):
|
||||
for data_cfg in self.dataset_configs.values():
|
||||
instantiate_from_config(data_cfg)
|
||||
|
||||
def setup(self, stage=None):
|
||||
self.datasets = dict(
|
||||
(k, instantiate_from_config(self.dataset_configs[k]))
|
||||
for k in self.dataset_configs)
|
||||
if self.wrap:
|
||||
for k in self.datasets:
|
||||
self.datasets[k] = WrappedDataset(self.datasets[k])
|
||||
|
||||
def _train_dataloader(self):
|
||||
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoaderX(self.datasets["train"], batch_size=self.batch_size,
|
||||
num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
|
||||
worker_init_fn=init_fn)
|
||||
|
||||
def _val_dataloader(self, shuffle=False):
|
||||
if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoaderX(self.datasets["validation"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
shuffle=shuffle)
|
||||
|
||||
def _test_dataloader(self, shuffle=False):
|
||||
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
|
||||
# do not shuffle dataloader for iterable dataset
|
||||
shuffle = shuffle and (not is_iterable_dataset)
|
||||
|
||||
return DataLoaderX(self.datasets["test"], batch_size=self.batch_size,
|
||||
num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
|
||||
|
||||
def _predict_dataloader(self, shuffle=False):
|
||||
if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size,
|
||||
num_workers=self.num_workers, worker_init_fn=init_fn)
|
||||
|
||||
|
||||
class SetupCallback(Callback):
|
||||
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
||||
super().__init__()
|
||||
self.resume = resume
|
||||
self.now = now
|
||||
self.logdir = logdir
|
||||
self.ckptdir = ckptdir
|
||||
self.cfgdir = cfgdir
|
||||
self.config = config
|
||||
self.lightning_config = lightning_config
|
||||
|
||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
print("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
# def on_pretrain_routine_start(self, trainer, pl_module):
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
# Create logdirs and save configs
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
os.makedirs(self.ckptdir, exist_ok=True)
|
||||
os.makedirs(self.cfgdir, exist_ok=True)
|
||||
|
||||
if "callbacks" in self.lightning_config:
|
||||
if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
|
||||
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
|
||||
print("Project config")
|
||||
print(OmegaConf.to_yaml(self.config))
|
||||
OmegaConf.save(self.config,
|
||||
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
|
||||
|
||||
print("Lightning config")
|
||||
print(OmegaConf.to_yaml(self.lightning_config))
|
||||
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
||||
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
||||
|
||||
else:
|
||||
# ModelCheckpoint callback created log directory --- remove it
|
||||
if not self.resume and os.path.exists(self.logdir):
|
||||
dst, name = os.path.split(self.logdir)
|
||||
dst = os.path.join(dst, "child_runs", name)
|
||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
||||
try:
|
||||
os.rename(self.logdir, dst)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class ImageLogger(Callback):
|
||||
def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
|
||||
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
|
||||
log_images_kwargs=None):
|
||||
super().__init__()
|
||||
self.rescale = rescale
|
||||
self.batch_freq = batch_frequency
|
||||
self.max_images = max_images
|
||||
self.logger_log_images = {
|
||||
pl.loggers.CSVLogger: self._testtube,
|
||||
}
|
||||
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
||||
if not increase_log_steps:
|
||||
self.log_steps = [self.batch_freq]
|
||||
self.clamp = clamp
|
||||
self.disabled = disabled
|
||||
self.log_on_batch_idx = log_on_batch_idx
|
||||
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
||||
self.log_first_step = log_first_step
|
||||
|
||||
@rank_zero_only
|
||||
def _testtube(self, pl_module, images, batch_idx, split):
|
||||
for k in images:
|
||||
grid = torchvision.utils.make_grid(images[k])
|
||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
|
||||
tag = f"{split}/{k}"
|
||||
pl_module.logger.experiment.add_image(
|
||||
tag, grid,
|
||||
global_step=pl_module.global_step)
|
||||
|
||||
@rank_zero_only
|
||||
def log_local(self, save_dir, split, images,
|
||||
global_step, current_epoch, batch_idx):
|
||||
root = os.path.join(save_dir, "images", split)
|
||||
for k in images:
|
||||
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
||||
if self.rescale:
|
||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
||||
grid = grid.numpy()
|
||||
grid = (grid * 255).astype(np.uint8)
|
||||
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
|
||||
k,
|
||||
global_step,
|
||||
current_epoch,
|
||||
batch_idx)
|
||||
path = os.path.join(root, filename)
|
||||
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
||||
Image.fromarray(grid).save(path)
|
||||
|
||||
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
||||
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
||||
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
|
||||
hasattr(pl_module, "log_images") and
|
||||
callable(pl_module.log_images) and
|
||||
self.max_images > 0):
|
||||
logger = type(pl_module.logger)
|
||||
|
||||
is_train = pl_module.training
|
||||
if is_train:
|
||||
pl_module.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
|
||||
|
||||
for k in images:
|
||||
N = min(images[k].shape[0], self.max_images)
|
||||
images[k] = images[k][:N]
|
||||
if isinstance(images[k], torch.Tensor):
|
||||
images[k] = images[k].detach().cpu()
|
||||
if self.clamp:
|
||||
images[k] = torch.clamp(images[k], -1., 1.)
|
||||
|
||||
self.log_local(pl_module.logger.save_dir, split, images,
|
||||
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
||||
|
||||
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
|
||||
logger_log_images(pl_module, images, pl_module.global_step, split)
|
||||
|
||||
if is_train:
|
||||
pl_module.train()
|
||||
|
||||
def check_frequency(self, check_idx):
|
||||
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
|
||||
check_idx > 0 or self.log_first_step):
|
||||
try:
|
||||
self.log_steps.pop(0)
|
||||
except IndexError as e:
|
||||
print(e)
|
||||
pass
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
# if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
||||
# self.log_img(pl_module, batch, batch_idx, split="train")
|
||||
pass
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if not self.disabled and pl_module.global_step > 0:
|
||||
self.log_img(pl_module, batch, batch_idx, split="val")
|
||||
if hasattr(pl_module, 'calibrate_grad_norm'):
|
||||
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
|
||||
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
||||
|
||||
|
||||
class CUDACallback(Callback):
|
||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
rank_zero_info("Training is starting")
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
rank_zero_info("Training is ending")
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
# Reset the memory use counter
|
||||
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
|
||||
epoch_time = time.time() - self.start_time
|
||||
|
||||
try:
|
||||
max_memory = trainer.strategy.reduce(max_memory)
|
||||
epoch_time = trainer.strategy.reduce(epoch_time)
|
||||
|
||||
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
||||
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# custom parser to specify config files, train, test and debug mode,
|
||||
# postfix, resume.
|
||||
# `--key value` arguments are interpreted as arguments to the trainer.
|
||||
# `nested.key=value` arguments are interpreted as config parameters.
|
||||
# configs are merged from left-to-right followed by command line parameters.
|
||||
|
||||
# model:
|
||||
# base_learning_rate: float
|
||||
# target: path to lightning module
|
||||
# params:
|
||||
# key: value
|
||||
# data:
|
||||
# target: main.DataModuleFromConfig
|
||||
# params:
|
||||
# batch_size: int
|
||||
# wrap: bool
|
||||
# train:
|
||||
# target: path to train dataset
|
||||
# params:
|
||||
# key: value
|
||||
# validation:
|
||||
# target: path to validation dataset
|
||||
# params:
|
||||
# key: value
|
||||
# test:
|
||||
# target: path to test dataset
|
||||
# params:
|
||||
# key: value
|
||||
# lightning: (optional, has sane defaults and can be specified on cmdline)
|
||||
# trainer:
|
||||
# additional arguments to trainer
|
||||
# logger:
|
||||
# logger to instantiate
|
||||
# modelcheckpoint:
|
||||
# modelcheckpoint to instantiate
|
||||
# callbacks:
|
||||
# callback1:
|
||||
# target: importpath
|
||||
# params:
|
||||
# key: value
|
||||
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
||||
|
||||
# add cwd for convenience and to make classes in this file available when
|
||||
# running as `python main.py`
|
||||
# (in particular `main.DataModuleFromConfig`)
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
parser = get_parser()
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
|
||||
opt, unknown = parser.parse_known_args()
|
||||
if opt.name and opt.resume:
|
||||
raise ValueError(
|
||||
"-n/--name and -r/--resume cannot be specified both."
|
||||
"If you want to resume training in a new log folder, "
|
||||
"use -n/--name in combination with --resume_from_checkpoint"
|
||||
)
|
||||
if opt.flash:
|
||||
enable_flash_attention()
|
||||
if opt.resume:
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
paths = opt.resume.split("/")
|
||||
# idx = len(paths)-paths[::-1].index("logs")+1
|
||||
# logdir = "/".join(paths[:idx])
|
||||
logdir = "/".join(paths[:-2])
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), opt.resume
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
|
||||
|
||||
opt.resume_from_checkpoint = ckpt
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
|
||||
opt.base = base_configs + opt.base
|
||||
_tmp = logdir.split("/")
|
||||
nowname = _tmp[-1]
|
||||
else:
|
||||
if opt.name:
|
||||
name = "_" + opt.name
|
||||
elif opt.base:
|
||||
cfg_fname = os.path.split(opt.base[0])[-1]
|
||||
cfg_name = os.path.splitext(cfg_fname)[0]
|
||||
name = "_" + cfg_name
|
||||
else:
|
||||
name = ""
|
||||
nowname = now + name + opt.postfix
|
||||
logdir = os.path.join(opt.logdir, nowname)
|
||||
|
||||
ckptdir = os.path.join(logdir, "checkpoints")
|
||||
cfgdir = os.path.join(logdir, "configs")
|
||||
seed_everything(opt.seed)
|
||||
|
||||
try:
|
||||
# init and save configs
|
||||
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
||||
cli = OmegaConf.from_dotlist(unknown)
|
||||
config = OmegaConf.merge(*configs, cli)
|
||||
lightning_config = config.pop("lightning", OmegaConf.create())
|
||||
# merge trainer cli with config
|
||||
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
||||
|
||||
for k in nondefault_trainer_args(opt):
|
||||
trainer_config[k] = getattr(opt, k)
|
||||
|
||||
print(trainer_config)
|
||||
if not trainer_config["accelerator"] == "gpu":
|
||||
del trainer_config["accelerator"]
|
||||
cpu = True
|
||||
print("Running on CPU")
|
||||
else:
|
||||
cpu = False
|
||||
print("Running on GPU")
|
||||
trainer_opt = argparse.Namespace(**trainer_config)
|
||||
lightning_config.trainer = trainer_config
|
||||
|
||||
# model
|
||||
use_fp16 = trainer_config.get("precision", 32) == 16
|
||||
if use_fp16:
|
||||
config.model["params"].update({"use_fp16": True})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
else:
|
||||
config.model["params"].update({"use_fp16": False})
|
||||
print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
# trainer and callbacks
|
||||
trainer_kwargs = dict()
|
||||
|
||||
# config the logger
|
||||
# default logger configs
|
||||
default_logger_cfgs = {
|
||||
"wandb": {
|
||||
"target": "pytorch_lightning.loggers.WandbLogger",
|
||||
"params": {
|
||||
"name": nowname,
|
||||
"save_dir": logdir,
|
||||
"offline": opt.debug,
|
||||
"id": nowname,
|
||||
}
|
||||
},
|
||||
"tensorboard":{
|
||||
"target": "pytorch_lightning.loggers.TensorBoardLogger",
|
||||
"params":{
|
||||
"save_dir": logdir,
|
||||
"name": "diff_tb",
|
||||
"log_graph": True
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default_logger_cfg = default_logger_cfgs["tensorboard"]
|
||||
if "logger" in lightning_config:
|
||||
logger_cfg = lightning_config.logger
|
||||
else:
|
||||
logger_cfg = default_logger_cfg
|
||||
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
|
||||
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
|
||||
|
||||
# config the strategy, defualt is ddp
|
||||
if "strategy" in trainer_config:
|
||||
strategy_cfg = trainer_config["strategy"]
|
||||
print("Using strategy: {}".format(strategy_cfg["target"]))
|
||||
else:
|
||||
strategy_cfg = {
|
||||
"target": "pytorch_lightning.strategies.DDPStrategy",
|
||||
"params": {
|
||||
"find_unused_parameters": False
|
||||
}
|
||||
}
|
||||
print("Using strategy: DDPStrategy")
|
||||
|
||||
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
|
||||
|
||||
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
|
||||
# specify which metric is used to determine best models
|
||||
default_modelckpt_cfg = {
|
||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
||||
"params": {
|
||||
"dirpath": ckptdir,
|
||||
"filename": "{epoch:06}",
|
||||
"verbose": True,
|
||||
"save_last": True,
|
||||
}
|
||||
}
|
||||
if hasattr(model, "monitor"):
|
||||
print(f"Monitoring {model.monitor} as checkpoint metric.")
|
||||
default_modelckpt_cfg["params"]["monitor"] = model.monitor
|
||||
default_modelckpt_cfg["params"]["save_top_k"] = 3
|
||||
|
||||
if "modelcheckpoint" in lightning_config:
|
||||
modelckpt_cfg = lightning_config.modelcheckpoint
|
||||
else:
|
||||
modelckpt_cfg = OmegaConf.create()
|
||||
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
|
||||
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
|
||||
if version.parse(pl.__version__) < version.parse('1.4.0'):
|
||||
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
|
||||
|
||||
# add callback which sets up log directory
|
||||
default_callbacks_cfg = {
|
||||
"setup_callback": {
|
||||
"target": "main.SetupCallback",
|
||||
"params": {
|
||||
"resume": opt.resume,
|
||||
"now": now,
|
||||
"logdir": logdir,
|
||||
"ckptdir": ckptdir,
|
||||
"cfgdir": cfgdir,
|
||||
"config": config,
|
||||
"lightning_config": lightning_config,
|
||||
}
|
||||
},
|
||||
"image_logger": {
|
||||
"target": "main.ImageLogger",
|
||||
"params": {
|
||||
"batch_frequency": 750,
|
||||
"max_images": 4,
|
||||
"clamp": True
|
||||
}
|
||||
},
|
||||
"learning_rate_logger": {
|
||||
"target": "main.LearningRateMonitor",
|
||||
"params": {
|
||||
"logging_interval": "step",
|
||||
# "log_momentum": True
|
||||
}
|
||||
},
|
||||
"cuda_callback": {
|
||||
"target": "main.CUDACallback"
|
||||
},
|
||||
}
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
|
||||
|
||||
if "callbacks" in lightning_config:
|
||||
callbacks_cfg = lightning_config.callbacks
|
||||
else:
|
||||
callbacks_cfg = OmegaConf.create()
|
||||
|
||||
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
|
||||
print(
|
||||
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
|
||||
default_metrics_over_trainsteps_ckpt_dict = {
|
||||
'metrics_over_trainsteps_checkpoint':
|
||||
{"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
|
||||
'params': {
|
||||
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
|
||||
"filename": "{epoch:06}-{step:09}",
|
||||
"verbose": True,
|
||||
'save_top_k': -1,
|
||||
'every_n_train_steps': 10000,
|
||||
'save_weights_only': True
|
||||
}
|
||||
}
|
||||
}
|
||||
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
|
||||
|
||||
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
|
||||
if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
|
||||
callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
|
||||
elif 'ignore_keys_callback' in callbacks_cfg:
|
||||
del callbacks_cfg['ignore_keys_callback']
|
||||
|
||||
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
|
||||
|
||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||
trainer.logdir = logdir ###
|
||||
|
||||
# data
|
||||
data = instantiate_from_config(config.data)
|
||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
||||
# calling these ourselves should not be necessary but it is.
|
||||
# lightning still takes care of proper multiprocessing though
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
print("#### Data #####")
|
||||
for k in data.datasets:
|
||||
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
|
||||
|
||||
# configure learning rate
|
||||
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
|
||||
if not cpu:
|
||||
ngpu = trainer_config["devices"]
|
||||
else:
|
||||
ngpu = 1
|
||||
if 'accumulate_grad_batches' in lightning_config.trainer:
|
||||
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
|
||||
else:
|
||||
accumulate_grad_batches = 1
|
||||
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
|
||||
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
|
||||
if opt.scale_lr:
|
||||
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
|
||||
print(
|
||||
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
|
||||
model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
|
||||
else:
|
||||
model.learning_rate = base_lr
|
||||
print("++++ NOT USING LR SCALING ++++")
|
||||
print(f"Setting learning rate to {model.learning_rate:.2e}")
|
||||
|
||||
|
||||
# allow checkpointing via USR1
|
||||
def melk(*args, **kwargs):
|
||||
# run all checkpoint hooks
|
||||
if trainer.global_rank == 0:
|
||||
print("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(ckptdir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
|
||||
def divein(*args, **kwargs):
|
||||
if trainer.global_rank == 0:
|
||||
import pudb;
|
||||
pudb.set_trace()
|
||||
|
||||
|
||||
import signal
|
||||
|
||||
signal.signal(signal.SIGUSR1, melk)
|
||||
signal.signal(signal.SIGUSR2, divein)
|
||||
|
||||
# run
|
||||
if opt.train:
|
||||
try:
|
||||
for name, m in model.named_parameters():
|
||||
print(name)
|
||||
trainer.fit(model, data)
|
||||
except Exception:
|
||||
melk()
|
||||
raise
|
||||
# if not opt.no_test and not trainer.interrupted:
|
||||
# trainer.test(model, data)
|
||||
except Exception:
|
||||
if opt.debug and trainer.global_rank == 0:
|
||||
try:
|
||||
import pudb as debugger
|
||||
except ImportError:
|
||||
import pdb as debugger
|
||||
debugger.post_mortem()
|
||||
raise
|
||||
finally:
|
||||
# move newly created debug project to debug_runs
|
||||
if opt.debug and not opt.resume and trainer.global_rank == 0:
|
||||
dst, name = os.path.split(logdir)
|
||||
dst = os.path.join(dst, "debug_runs", name)
|
||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
||||
os.rename(logdir, dst)
|
||||
if trainer.global_rank == 0:
|
||||
print(trainer.profiler.summary())
|
|
@ -1,20 +0,0 @@
|
|||
albumentations==0.4.3
|
||||
diffusers
|
||||
opencv-python==4.1.2.30
|
||||
pudb==2019.2
|
||||
invisible-watermark
|
||||
imageio==2.9.0
|
||||
imageio-ffmpeg==0.4.2
|
||||
omegaconf==2.1.1
|
||||
test-tube>=0.7.5
|
||||
streamlit>=0.73.1
|
||||
einops==0.3.0
|
||||
torch-fidelity==0.3.0
|
||||
transformers==4.19.2
|
||||
torchmetrics==0.6.0
|
||||
kornia==0.6
|
||||
opencv-python==4.6.0.66
|
||||
prefetch_generator
|
||||
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
-e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
-e .
|
|
@ -1,41 +0,0 @@
|
|||
#!/bin/bash
|
||||
wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
|
||||
wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
|
||||
wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
|
||||
wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
|
||||
wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
|
||||
wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
|
||||
wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
|
||||
wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
|
||||
wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
|
||||
|
||||
|
||||
|
||||
cd models/first_stage_models/kl-f4
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../kl-f8
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../kl-f16
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../kl-f32
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f4
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f4-noattn
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f8
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f8-n256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../vq-f16
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../..
|
|
@ -1,49 +0,0 @@
|
|||
#!/bin/bash
|
||||
wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
|
||||
wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
|
||||
wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
|
||||
wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
|
||||
wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
|
||||
wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
|
||||
wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
|
||||
wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
|
||||
wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
|
||||
wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
|
||||
wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
|
||||
|
||||
|
||||
|
||||
cd models/ldm/celeba256
|
||||
unzip -o celeba-256.zip
|
||||
|
||||
cd ../ffhq256
|
||||
unzip -o ffhq-256.zip
|
||||
|
||||
cd ../lsun_churches256
|
||||
unzip -o lsun_churches-256.zip
|
||||
|
||||
cd ../lsun_beds256
|
||||
unzip -o lsun_beds-256.zip
|
||||
|
||||
cd ../text2img256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../cin256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../semantic_synthesis512
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../semantic_synthesis256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../bsr_sr
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../layout2img-openimages256
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../inpainting_big
|
||||
unzip -o model.zip
|
||||
|
||||
cd ../..
|
|
@ -1,293 +0,0 @@
|
|||
"""make variations of input image"""
|
||||
|
||||
import argparse, os, sys, glob
|
||||
import PIL
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
from torch import autocast
|
||||
from contextlib import nullcontext
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_img(path):
|
||||
image = Image.open(path).convert("RGB")
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.*image - 1.
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--init-img",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="path to the input image"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/img2img-samples"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action='store_true',
|
||||
help="do not save indiviual samples. For speed measurements.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--plms",
|
||||
action='store_true',
|
||||
help="use plms sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fixed_code",
|
||||
action='store_true',
|
||||
help="if enabled, uses the same starting code across all samples ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=1,
|
||||
help="sample this often",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--C",
|
||||
type=int,
|
||||
default=4,
|
||||
help="latent channels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--f",
|
||||
type=int,
|
||||
default=8,
|
||||
help="downsampling factor, most often 8 or 16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=2,
|
||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/stable-diffusion/v1-inference.yaml",
|
||||
help="path to config which constructs model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
help="evaluate at this precision",
|
||||
choices=["full", "autocast"],
|
||||
default="autocast"
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
seed_everything(opt.seed)
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
|
||||
if opt.plms:
|
||||
raise NotImplementedError("PLMS sampler not (yet) supported")
|
||||
sampler = PLMSSampler(model)
|
||||
else:
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
assert os.path.isfile(opt.init_img)
|
||||
init_image = load_img(opt.init_img).to(device)
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
|
||||
|
||||
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(opt.strength * opt.ddim_steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
|
||||
# decode it
|
||||
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,)
|
||||
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if not opt.skip_save:
|
||||
for x_sample in x_samples:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
base_count += 1
|
||||
all_samples.append(x_samples)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||
f" \nEnjoy.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,98 +0,0 @@
|
|||
import argparse, os, sys, glob
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
from main import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
|
||||
def make_batch(image, mask, device):
|
||||
image = np.array(Image.open(image).convert("RGB"))
|
||||
image = image.astype(np.float32)/255.0
|
||||
image = image[None].transpose(0,3,1,2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask = np.array(Image.open(mask).convert("L"))
|
||||
mask = mask.astype(np.float32)/255.0
|
||||
mask = mask[None,None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = (1-mask)*image
|
||||
|
||||
batch = {"image": image, "mask": mask, "masked_image": masked_image}
|
||||
for k in batch:
|
||||
batch[k] = batch[k].to(device=device)
|
||||
batch[k] = batch[k]*2.0-1.0
|
||||
return batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--indir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
|
||||
images = [x.replace("_mask.png", ".png") for x in masks]
|
||||
print(f"Found {len(masks)} inputs.")
|
||||
|
||||
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
|
||||
strict=False)
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
with torch.no_grad():
|
||||
with model.ema_scope():
|
||||
for image, mask in tqdm(zip(images, masks)):
|
||||
outpath = os.path.join(opt.outdir, os.path.split(image)[1])
|
||||
batch = make_batch(image, mask, device=device)
|
||||
|
||||
# encode masked image and concat downsampled mask
|
||||
c = model.cond_stage_model.encode(batch["masked_image"])
|
||||
cc = torch.nn.functional.interpolate(batch["mask"],
|
||||
size=c.shape[-2:])
|
||||
c = torch.cat((c, cc), dim=1)
|
||||
|
||||
shape = (c.shape[1]-1,)+c.shape[2:]
|
||||
samples_ddim, _ = sampler.sample(S=opt.steps,
|
||||
conditioning=c,
|
||||
batch_size=c.shape[0],
|
||||
shape=shape,
|
||||
verbose=False)
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
|
||||
image = torch.clamp((batch["image"]+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
mask = torch.clamp((batch["mask"]+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
|
||||
inpainted = (1-mask)*image+mask*predicted_image
|
||||
inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
|
||||
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
|
|
@ -1,398 +0,0 @@
|
|||
import argparse, os, sys, glob
|
||||
import clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
import scann
|
||||
import time
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from ldm.util import instantiate_from_config, parallel_data_prefetch
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
|
||||
|
||||
DATABASES = [
|
||||
"openimages",
|
||||
"artbench-art_nouveau",
|
||||
"artbench-baroque",
|
||||
"artbench-expressionism",
|
||||
"artbench-impressionism",
|
||||
"artbench-post_impressionism",
|
||||
"artbench-realism",
|
||||
"artbench-romanticism",
|
||||
"artbench-renaissance",
|
||||
"artbench-surrealism",
|
||||
"artbench-ukiyo_e",
|
||||
]
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class Searcher(object):
|
||||
def __init__(self, database, retriever_version='ViT-L/14'):
|
||||
assert database in DATABASES
|
||||
# self.database = self.load_database(database)
|
||||
self.database_name = database
|
||||
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
|
||||
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
|
||||
self.retriever = self.load_retriever(version=retriever_version)
|
||||
self.database = {'embedding': [],
|
||||
'img_id': [],
|
||||
'patch_coords': []}
|
||||
self.load_database()
|
||||
self.load_searcher()
|
||||
|
||||
def train_searcher(self, k,
|
||||
metric='dot_product',
|
||||
searcher_savedir=None):
|
||||
|
||||
print('Start training searcher')
|
||||
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
|
||||
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
|
||||
k, metric)
|
||||
self.searcher = searcher.score_brute_force().build()
|
||||
print('Finish training searcher')
|
||||
|
||||
if searcher_savedir is not None:
|
||||
print(f'Save trained searcher under "{searcher_savedir}"')
|
||||
os.makedirs(searcher_savedir, exist_ok=True)
|
||||
self.searcher.serialize(searcher_savedir)
|
||||
|
||||
def load_single_file(self, saved_embeddings):
|
||||
compressed = np.load(saved_embeddings)
|
||||
self.database = {key: compressed[key] for key in compressed.files}
|
||||
print('Finished loading of clip embeddings.')
|
||||
|
||||
def load_multi_files(self, data_archive):
|
||||
out_data = {key: [] for key in self.database}
|
||||
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||
for key in d.files:
|
||||
out_data[key].append(d[key])
|
||||
|
||||
return out_data
|
||||
|
||||
def load_database(self):
|
||||
|
||||
print(f'Load saved patch embedding from "{self.database_path}"')
|
||||
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
|
||||
|
||||
if len(file_content) == 1:
|
||||
self.load_single_file(file_content[0])
|
||||
elif len(file_content) > 1:
|
||||
data = [np.load(f) for f in file_content]
|
||||
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
|
||||
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||
|
||||
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
|
||||
self.database}
|
||||
else:
|
||||
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
|
||||
|
||||
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
|
||||
|
||||
def load_retriever(self, version='ViT-L/14', ):
|
||||
model = FrozenClipImageEmbedder(model=version)
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def load_searcher(self):
|
||||
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
|
||||
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
|
||||
print('Finished loading searcher.')
|
||||
|
||||
def search(self, x, k):
|
||||
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
|
||||
self.train_searcher(k) # quickly fit searcher on the fly for small databases
|
||||
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
if len(x.shape) == 3:
|
||||
x = x[:, 0]
|
||||
query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
|
||||
|
||||
start = time.time()
|
||||
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
|
||||
end = time.time()
|
||||
|
||||
out_embeddings = self.database['embedding'][nns]
|
||||
out_img_ids = self.database['img_id'][nns]
|
||||
out_pc = self.database['patch_coords'][nns]
|
||||
|
||||
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
|
||||
'img_ids': out_img_ids,
|
||||
'patch_coords': out_pc,
|
||||
'queries': x,
|
||||
'exec_time': end - start,
|
||||
'nns': nns,
|
||||
'q_embeddings': query_embeddings}
|
||||
|
||||
return out
|
||||
|
||||
def __call__(self, x, n):
|
||||
return self.search(x, n)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
|
||||
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_repeat",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of repeats in CLIP latent space",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--plms",
|
||||
action='store_true',
|
||||
help="use plms sampling",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=1,
|
||||
help="sample this often",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=int,
|
||||
default=768,
|
||||
help="image height, in pixel space",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--W",
|
||||
type=int,
|
||||
default=768,
|
||||
help="image width, in pixel space",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=3,
|
||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/retrieval-augmented-diffusion/768x768.yaml",
|
||||
help="path to config which constructs model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="models/rdm/rdm768x768/model.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--clip_type",
|
||||
type=str,
|
||||
default="ViT-L/14",
|
||||
help="which CLIP model to use for retrieval and NN encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--database",
|
||||
type=str,
|
||||
default='artbench-surrealism',
|
||||
choices=DATABASES,
|
||||
help="The database used for the search, only applied when --use_neighbors=True",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_neighbors",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Include neighbors in addition to text prompt for conditioning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--knn",
|
||||
default=10,
|
||||
type=int,
|
||||
help="The number of included neighbors, only applied when --use_neighbors=True",
|
||||
)
|
||||
|
||||
opt = parser.parse_args()
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
|
||||
clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
|
||||
|
||||
if opt.plms:
|
||||
sampler = PLMSSampler(model)
|
||||
else:
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
print(f"sampling scale for cfg is {opt.scale:.2f}")
|
||||
|
||||
searcher = None
|
||||
if opt.use_neighbors:
|
||||
searcher = Searcher(opt.database)
|
||||
|
||||
with torch.no_grad():
|
||||
with model.ema_scope():
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
all_samples = list()
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
print("sampling prompts:", prompts)
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = clip_text_encoder.encode(prompts)
|
||||
uc = None
|
||||
if searcher is not None:
|
||||
nn_dict = searcher(c, opt.knn)
|
||||
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
|
||||
if opt.scale != 1.0:
|
||||
uc = torch.zeros_like(c)
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
|
||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=c.shape[0],
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
)
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
base_count += 1
|
||||
all_samples.append(x_samples_ddim)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
|
@ -1,313 +0,0 @@
|
|||
import argparse, os, sys, glob, datetime, yaml
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
rescale = lambda x: (x + 1.) / 2.
|
||||
|
||||
def custom_to_pil(x):
|
||||
x = x.detach().cpu()
|
||||
x = torch.clamp(x, -1., 1.)
|
||||
x = (x + 1.) / 2.
|
||||
x = x.permute(1, 2, 0).numpy()
|
||||
x = (255 * x).astype(np.uint8)
|
||||
x = Image.fromarray(x)
|
||||
if not x.mode == "RGB":
|
||||
x = x.convert("RGB")
|
||||
return x
|
||||
|
||||
|
||||
def custom_to_np(x):
|
||||
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
|
||||
sample = x.detach().cpu()
|
||||
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
|
||||
sample = sample.permute(0, 2, 3, 1)
|
||||
sample = sample.contiguous()
|
||||
return sample
|
||||
|
||||
|
||||
def logs2pil(logs, keys=["sample"]):
|
||||
imgs = dict()
|
||||
for k in logs:
|
||||
try:
|
||||
if len(logs[k].shape) == 4:
|
||||
img = custom_to_pil(logs[k][0, ...])
|
||||
elif len(logs[k].shape) == 3:
|
||||
img = custom_to_pil(logs[k])
|
||||
else:
|
||||
print(f"Unknown format for key {k}. ")
|
||||
img = None
|
||||
except:
|
||||
img = None
|
||||
imgs[k] = img
|
||||
return imgs
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample(model, shape, return_intermediates=True,
|
||||
verbose=True,
|
||||
make_prog_row=False):
|
||||
|
||||
|
||||
if not make_prog_row:
|
||||
return model.p_sample_loop(None, shape,
|
||||
return_intermediates=return_intermediates, verbose=verbose)
|
||||
else:
|
||||
return model.progressive_denoising(
|
||||
None, shape, verbose=True
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample_ddim(model, steps, shape, eta=1.0
|
||||
):
|
||||
ddim = DDIMSampler(model)
|
||||
bs = shape[0]
|
||||
shape = shape[1:]
|
||||
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
|
||||
|
||||
|
||||
log = dict()
|
||||
|
||||
shape = [batch_size,
|
||||
model.model.diffusion_model.in_channels,
|
||||
model.model.diffusion_model.image_size,
|
||||
model.model.diffusion_model.image_size]
|
||||
|
||||
with model.ema_scope("Plotting"):
|
||||
t0 = time.time()
|
||||
if vanilla:
|
||||
sample, progrow = convsample(model, shape,
|
||||
make_prog_row=True)
|
||||
else:
|
||||
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
|
||||
eta=eta)
|
||||
|
||||
t1 = time.time()
|
||||
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
log["sample"] = x_sample
|
||||
log["time"] = t1 - t0
|
||||
log['throughput'] = sample.shape[0] / (t1 - t0)
|
||||
print(f'Throughput for this batch: {log["throughput"]}')
|
||||
return log
|
||||
|
||||
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
|
||||
if vanilla:
|
||||
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
|
||||
else:
|
||||
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
|
||||
|
||||
|
||||
tstart = time.time()
|
||||
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
|
||||
# path = logdir
|
||||
if model.cond_stage_model is None:
|
||||
all_images = []
|
||||
|
||||
print(f"Running unconditional sampling for {n_samples} samples")
|
||||
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
|
||||
logs = make_convolutional_sample(model, batch_size=batch_size,
|
||||
vanilla=vanilla, custom_steps=custom_steps,
|
||||
eta=eta)
|
||||
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
|
||||
all_images.extend([custom_to_np(logs["sample"])])
|
||||
if n_saved >= n_samples:
|
||||
print(f'Finish after generating {n_saved} samples')
|
||||
break
|
||||
all_img = np.concatenate(all_images, axis=0)
|
||||
all_img = all_img[:n_samples]
|
||||
shape_str = "x".join([str(x) for x in all_img.shape])
|
||||
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
|
||||
np.savez(nppath, all_img)
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Currently only sampling for unconditional models supported.')
|
||||
|
||||
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
|
||||
|
||||
|
||||
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
|
||||
for k in logs:
|
||||
if k == key:
|
||||
batch = logs[key]
|
||||
if np_path is None:
|
||||
for x in batch:
|
||||
img = custom_to_pil(x)
|
||||
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
|
||||
img.save(imgpath)
|
||||
n_saved += 1
|
||||
else:
|
||||
npbatch = custom_to_np(batch)
|
||||
shape_str = "x".join([str(x) for x in npbatch.shape])
|
||||
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
|
||||
np.savez(nppath, npbatch)
|
||||
n_saved += npbatch.shape[0]
|
||||
return n_saved
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--resume",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="load from logdir or checkpoint in logdir",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--n_samples",
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of samples to draw",
|
||||
default=50000
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--eta",
|
||||
type=float,
|
||||
nargs="?",
|
||||
help="eta for ddim sampling (0.0 yields deterministic sampling)",
|
||||
default=1.0
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--vanilla_sample",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="vanilla sampling (default option is DDIM sampling)?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--logdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="extra logdir",
|
||||
default="none"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--custom_steps",
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="number of steps for ddim and fastdpm sampling",
|
||||
default=50
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
nargs="?",
|
||||
help="the bs",
|
||||
default=10
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def load_model_from_config(config, sd):
|
||||
model = instantiate_from_config(config)
|
||||
model.load_state_dict(sd,strict=False)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def load_model(config, ckpt, gpu, eval_mode):
|
||||
if ckpt:
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
global_step = pl_sd["global_step"]
|
||||
else:
|
||||
pl_sd = {"state_dict": None}
|
||||
global_step = None
|
||||
model = load_model_from_config(config.model,
|
||||
pl_sd["state_dict"])
|
||||
|
||||
return model, global_step
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
sys.path.append(os.getcwd())
|
||||
command = " ".join(sys.argv)
|
||||
|
||||
parser = get_parser()
|
||||
opt, unknown = parser.parse_known_args()
|
||||
ckpt = None
|
||||
|
||||
if not os.path.exists(opt.resume):
|
||||
raise ValueError("Cannot find {}".format(opt.resume))
|
||||
if os.path.isfile(opt.resume):
|
||||
# paths = opt.resume.split("/")
|
||||
try:
|
||||
logdir = '/'.join(opt.resume.split('/')[:-1])
|
||||
# idx = len(paths)-paths[::-1].index("logs")+1
|
||||
print(f'Logdir is {logdir}')
|
||||
except ValueError:
|
||||
paths = opt.resume.split("/")
|
||||
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
|
||||
logdir = "/".join(paths[:idx])
|
||||
ckpt = opt.resume
|
||||
else:
|
||||
assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
|
||||
logdir = opt.resume.rstrip("/")
|
||||
ckpt = os.path.join(logdir, "model.ckpt")
|
||||
|
||||
base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
|
||||
opt.base = base_configs
|
||||
|
||||
configs = [OmegaConf.load(cfg) for cfg in opt.base]
|
||||
cli = OmegaConf.from_dotlist(unknown)
|
||||
config = OmegaConf.merge(*configs, cli)
|
||||
|
||||
gpu = True
|
||||
eval_mode = True
|
||||
|
||||
if opt.logdir != "none":
|
||||
locallog = logdir.split(os.sep)[-1]
|
||||
if locallog == "": locallog = logdir.split(os.sep)[-2]
|
||||
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
|
||||
logdir = os.path.join(opt.logdir, locallog)
|
||||
|
||||
print(config)
|
||||
|
||||
model, global_step = load_model(config, ckpt, gpu, eval_mode)
|
||||
print(f"global step: {global_step}")
|
||||
print(75 * "=")
|
||||
print("logging to:")
|
||||
logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
|
||||
imglogdir = os.path.join(logdir, "img")
|
||||
numpylogdir = os.path.join(logdir, "numpy")
|
||||
|
||||
os.makedirs(imglogdir)
|
||||
os.makedirs(numpylogdir)
|
||||
print(logdir)
|
||||
print(75 * "=")
|
||||
|
||||
# write config out
|
||||
sampling_file = os.path.join(logdir, "sampling_config.yaml")
|
||||
sampling_conf = vars(opt)
|
||||
|
||||
with open(sampling_file, 'w') as f:
|
||||
yaml.dump(sampling_conf, f, default_flow_style=False)
|
||||
print(sampling_conf)
|
||||
|
||||
|
||||
run(model, imglogdir, eta=opt.eta,
|
||||
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
|
||||
batch_size=opt.batch_size, nplog=numpylogdir)
|
||||
|
||||
print("done.")
|
|
@ -1,37 +0,0 @@
|
|||
import os
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
from ldm.util import instantiate_from_config
|
||||
from main import get_parser
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.no_grad():
|
||||
yaml_path = "../../train_colossalai.yaml"
|
||||
with open(yaml_path, 'r', encoding='utf-8') as f:
|
||||
config = f.read()
|
||||
base_config = yaml.load(config, Loader=yaml.FullLoader)
|
||||
unet_config = base_config['model']['params']['unet_config']
|
||||
diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"/data/scratch/diffuser/stable-diffusion-v1-4"
|
||||
).to("cuda:0")
|
||||
dif_model_2 = pipe.unet
|
||||
|
||||
random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0")
|
||||
random_input_2 = torch.clone(random_input_).to("cuda:0")
|
||||
time_stamp = torch.randint(20, (4,)).to("cuda:0")
|
||||
time_stamp2 = torch.clone(time_stamp).to("cuda:0")
|
||||
context_ = torch.rand((4, 77, 768)).to("cuda:0")
|
||||
context_2 = torch.clone(context_).to("cuda:0")
|
||||
|
||||
out_1 = diffusion_model(random_input_, time_stamp, context_)
|
||||
out_2 = dif_model_2(random_input_2, time_stamp2, context_2)
|
||||
print(out_1.shape)
|
||||
print(out_2['sample'].shape)
|
|
@ -1,18 +0,0 @@
|
|||
import cv2
|
||||
import fire
|
||||
from imwatermark import WatermarkDecoder
|
||||
|
||||
|
||||
def testit(img_path):
|
||||
bgr = cv2.imread(img_path)
|
||||
decoder = WatermarkDecoder('bytes', 136)
|
||||
watermark = decoder.decode(bgr, 'dwtDct')
|
||||
try:
|
||||
dec = watermark.decode('utf-8')
|
||||
except:
|
||||
dec = "null"
|
||||
print(dec)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(testit)
|
|
@ -1,147 +0,0 @@
|
|||
import os, sys
|
||||
import numpy as np
|
||||
import scann
|
||||
import argparse
|
||||
import glob
|
||||
from multiprocessing import cpu_count
|
||||
from tqdm import tqdm
|
||||
|
||||
from ldm.util import parallel_data_prefetch
|
||||
|
||||
|
||||
def search_bruteforce(searcher):
|
||||
return searcher.score_brute_force().build()
|
||||
|
||||
|
||||
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
|
||||
partioning_trainsize, num_leaves, num_leaves_to_search):
|
||||
return searcher.tree(num_leaves=num_leaves,
|
||||
num_leaves_to_search=num_leaves_to_search,
|
||||
training_sample_size=partioning_trainsize). \
|
||||
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
|
||||
|
||||
|
||||
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
|
||||
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
|
||||
reorder_k).build()
|
||||
|
||||
def load_datapool(dpath):
|
||||
|
||||
|
||||
def load_single_file(saved_embeddings):
|
||||
compressed = np.load(saved_embeddings)
|
||||
database = {key: compressed[key] for key in compressed.files}
|
||||
return database
|
||||
|
||||
def load_multi_files(data_archive):
|
||||
database = {key: [] for key in data_archive[0].files}
|
||||
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||
for key in d.files:
|
||||
database[key].append(d[key])
|
||||
|
||||
return database
|
||||
|
||||
print(f'Load saved patch embedding from "{dpath}"')
|
||||
file_content = glob.glob(os.path.join(dpath, '*.npz'))
|
||||
|
||||
if len(file_content) == 1:
|
||||
data_pool = load_single_file(file_content[0])
|
||||
elif len(file_content) > 1:
|
||||
data = [np.load(f) for f in file_content]
|
||||
prefetched_data = parallel_data_prefetch(load_multi_files, data,
|
||||
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||
|
||||
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
|
||||
else:
|
||||
raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
|
||||
|
||||
print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
|
||||
return data_pool
|
||||
|
||||
|
||||
def train_searcher(opt,
|
||||
metric='dot_product',
|
||||
partioning_trainsize=None,
|
||||
reorder_k=None,
|
||||
# todo tune
|
||||
aiq_thld=0.2,
|
||||
dims_per_block=2,
|
||||
num_leaves=None,
|
||||
num_leaves_to_search=None,):
|
||||
|
||||
data_pool = load_datapool(opt.database)
|
||||
k = opt.knn
|
||||
|
||||
if not reorder_k:
|
||||
reorder_k = 2 * k
|
||||
|
||||
# normalize
|
||||
# embeddings =
|
||||
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
|
||||
pool_size = data_pool['embedding'].shape[0]
|
||||
|
||||
print(*(['#'] * 100))
|
||||
print('Initializing scaNN searcher with the following values:')
|
||||
print(f'k: {k}')
|
||||
print(f'metric: {metric}')
|
||||
print(f'reorder_k: {reorder_k}')
|
||||
print(f'anisotropic_quantization_threshold: {aiq_thld}')
|
||||
print(f'dims_per_block: {dims_per_block}')
|
||||
print(*(['#'] * 100))
|
||||
print('Start training searcher....')
|
||||
print(f'N samples in pool is {pool_size}')
|
||||
|
||||
# this reflects the recommended design choices proposed at
|
||||
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
|
||||
if pool_size < 2e4:
|
||||
print('Using brute force search.')
|
||||
searcher = search_bruteforce(searcher)
|
||||
elif 2e4 <= pool_size and pool_size < 1e5:
|
||||
print('Using asymmetric hashing search and reordering.')
|
||||
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
|
||||
else:
|
||||
print('Using using partioning, asymmetric hashing search and reordering.')
|
||||
|
||||
if not partioning_trainsize:
|
||||
partioning_trainsize = data_pool['embedding'].shape[0] // 10
|
||||
if not num_leaves:
|
||||
num_leaves = int(np.sqrt(pool_size))
|
||||
|
||||
if not num_leaves_to_search:
|
||||
num_leaves_to_search = max(num_leaves // 20, 1)
|
||||
|
||||
print('Partitioning params:')
|
||||
print(f'num_leaves: {num_leaves}')
|
||||
print(f'num_leaves_to_search: {num_leaves_to_search}')
|
||||
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
|
||||
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
|
||||
partioning_trainsize, num_leaves, num_leaves_to_search)
|
||||
|
||||
print('Finish training searcher')
|
||||
searcher_savedir = opt.target_path
|
||||
os.makedirs(searcher_savedir, exist_ok=True)
|
||||
searcher.serialize(searcher_savedir)
|
||||
print(f'Saved trained searcher under "{searcher_savedir}"')
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.path.append(os.getcwd())
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--database',
|
||||
'-d',
|
||||
default='data/rdm/retrieval_databases/openimages',
|
||||
type=str,
|
||||
help='path to folder containing the clip feature of the database')
|
||||
parser.add_argument('--target_path',
|
||||
'-t',
|
||||
default='data/rdm/searchers/openimages',
|
||||
type=str,
|
||||
help='path to the target folder where the searcher shall be stored.')
|
||||
parser.add_argument('--knn',
|
||||
'-k',
|
||||
default=20,
|
||||
type=int,
|
||||
help='number of nearest neighbors, for which the searcher shall be optimized')
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
train_searcher(opt,)
|
|
@ -1,344 +0,0 @@
|
|||
import argparse, os, sys, glob
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from imwatermark import WatermarkEncoder
|
||||
from itertools import islice
|
||||
from einops import rearrange
|
||||
from torchvision.utils import make_grid
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
# load safety model
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def put_watermark(img, wm_encoder=None):
|
||||
if wm_encoder is not None:
|
||||
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
img = wm_encoder.encode(img, 'dwtDct')
|
||||
img = Image.fromarray(img[:, :, ::-1])
|
||||
return img
|
||||
|
||||
|
||||
def load_replacement(x):
|
||||
try:
|
||||
hwc = x.shape
|
||||
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
|
||||
y = (np.array(y)/255.0).astype(x.dtype)
|
||||
assert y.shape == x.shape
|
||||
return y
|
||||
except Exception:
|
||||
return x
|
||||
|
||||
|
||||
def check_safety(x_image):
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
assert x_checked_image.shape[0] == len(has_nsfw_concept)
|
||||
for i in range(len(has_nsfw_concept)):
|
||||
if has_nsfw_concept[i]:
|
||||
x_checked_image[i] = load_replacement(x_checked_image[i])
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action='store_true',
|
||||
help="do not save individual samples. For speed measurements.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plms",
|
||||
action='store_true',
|
||||
help="use plms sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--laion400m",
|
||||
action='store_true',
|
||||
help="uses the LAION400M model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fixed_code",
|
||||
action='store_true',
|
||||
help="if enabled, uses the same starting code across samples ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=2,
|
||||
help="sample this often",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=int,
|
||||
default=512,
|
||||
help="image height, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W",
|
||||
type=int,
|
||||
default=512,
|
||||
help="image width, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--C",
|
||||
type=int,
|
||||
default=4,
|
||||
help="latent channels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--f",
|
||||
type=int,
|
||||
default=8,
|
||||
help="downsampling factor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=3,
|
||||
help="how many samples to produce for each given prompt. A.k.a. batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/stable-diffusion/v1-inference.yaml",
|
||||
help="path to config which constructs model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
help="evaluate at this precision",
|
||||
choices=["full", "autocast"],
|
||||
default="autocast"
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
if opt.laion400m:
|
||||
print("Falling back to LAION 400M model...")
|
||||
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
|
||||
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
|
||||
opt.outdir = "outputs/txt2img-samples-laion400m"
|
||||
|
||||
seed_everything(opt.seed)
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = model.to(device)
|
||||
|
||||
if opt.plms:
|
||||
sampler = PLMSSampler(model)
|
||||
else:
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
|
||||
wm = "StableDiffusionV1"
|
||||
wm_encoder = WatermarkEncoder()
|
||||
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
start_code = None
|
||||
if opt.fixed_code:
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||
|
||||
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=opt.n_samples,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
x_T=start_code)
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
|
||||
|
||||
x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
||||
|
||||
if not opt.skip_save:
|
||||
for x_sample in x_checked_image_torch:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
img = Image.fromarray(x_sample.astype(np.uint8))
|
||||
img = put_watermark(img, wm_encoder)
|
||||
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
base_count += 1
|
||||
|
||||
if not opt.skip_grid:
|
||||
all_samples.append(x_checked_image_torch)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
img = Image.fromarray(grid.astype(np.uint8))
|
||||
img = put_watermark(img, wm_encoder)
|
||||
img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
toc = time.time()
|
||||
|
||||
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||
f" \nEnjoy.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,13 +0,0 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='latent-diffusion',
|
||||
version='0.0.1',
|
||||
description='',
|
||||
packages=find_packages(),
|
||||
install_requires=[
|
||||
'torch',
|
||||
'numpy',
|
||||
'tqdm',
|
||||
],
|
||||
)
|
|
@ -1,4 +0,0 @@
|
|||
HF_DATASETS_OFFLINE=1
|
||||
TRANSFORMERS_OFFLINE=1
|
||||
|
||||
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
|
Loading…
Reference in New Issue