mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] add cifar10 for diffusion (#1907)
parent 14a0b18305
commit 11ee8ae478

@ -1,5 +1,4 @@

# Stable Diffusion with Colossal-AI

*[Colossal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower-cost solution for pretraining and fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*

@ -55,28 +54,40 @@ pip install -r requirements.txt && pip install .

> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.

## Dataset

The dataset is from [LAION-5B](https://laion.ai/blog/laion-5b/), a subset of [LAION](https://laion.ai/). You should change `data.file_path` in `config/train_colossalai.yaml` accordingly; a sketch of that block follows.
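
A minimal sketch of the relevant `data` block in `config/train_colossalai.yaml` (the surrounding keys mirror the configs shipped with this example; the keys around `file_path` are illustrative, and only `file_path` itself must point at your local copy of the dataset):

```
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    train:
      params:
        # path to your local LAION subset (illustrative value)
        file_path: /data/laion-part0/
```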

## Training

We provide the script `train.sh` to run the training task, along with training strategy configs under `configs`, e.g. `train_colossalai.yaml`.

For example, you can run the training with the Colossal-AI strategy by

```
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
```

- you can change `--logdir` to decide where to save the log information and the last checkpoint, as in the example below
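
For instance, to keep logs and the final checkpoint under a local directory instead of `/tmp` (the directory name here is just an example):

```
python main.py --logdir ./diff_log -t --postfix test -b configs/train_colossalai.yaml
```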

### Training config

You can change the training config in the yaml file:

- accelerator: accelerator type, default 'gpu'
- devices: the number of devices used for training, default 4
- max_epochs: max training epochs
- precision: whether to use fp16 for training, default 16; you must use fp16 if you want to apply colossalai

These options map onto the `lightning.trainer` block of the yaml file, as sketched below.
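
A sketch of the corresponding `lightning.trainer` block (values taken from the CIFAR10 config added by this commit; adjust them to your setup):

```
lightning:
  trainer:
    accelerator: 'gpu'  # accelerator type
    devices: 2          # number of devices used for training
    max_epochs: 2       # max training epochs
    precision: 16       # fp16; required for the colossalai strategy
```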

## Example

### Training on CIFAR10

We provide a fine-tuning example on the CIFAR10 dataset.

You can run it with the config `train_colossalai_cifar10.yaml`:

```
python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
```

## Comments

@ -0,0 +1,123 @@

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1.e-4 ]
        f_min: [ 1.e-10 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        use_fp16: True

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    train:
      target: ldm.data.cifar10.hf_dataset
      params:
        name: cifar10
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip

lightning:
  trainer:
    accelerator: 'gpu'
    devices: 2
    log_gpu_memory: all
    max_epochs: 2
    precision: 16
    auto_select_gpus: False
    strategy:
      target: pytorch_lightning.strategies.ColossalAIStrategy
      params:
        use_chunk: False
        enable_distributed_storage: True
        placement_policy: cuda
        force_outputs_fp32: False

    log_every_n_steps: 2
    logger: True
    default_root_dir: "/tmp/diff_log/"
    profiler: pytorch

  logger_config:
    wandb:
      target: pytorch_lightning.loggers.WandbLogger
      params:
        name: nowname
        save_dir: "/tmp/diff_log/"
        offline: opt.debug
        id: nowname
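
A quick way to sanity-check this new CIFAR10 config before a full run is to load it and instantiate just the data module. This is a sketch assuming OmegaConf and the repo's `main.DataModuleFromConfig`, with the config path taken from the README command above:

```
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# parse the yaml shown above
config = OmegaConf.load("configs/train_colossalai_cifar10.yaml")

# build the lightning datamodule from its config node, as main.py does
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()

# a training batch should contain the 'image' and 'txt' keys used by the model
batch = next(iter(data.train_dataloader()))
print(batch.keys())
```
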
@ -11,20 +11,21 @@ dependencies:
  - numpy=1.19.2
  - pip:
    - albumentations==0.4.3
    - datasets
    - diffusers
    - opencv-python==4.6.0.66
    - pudb==2019.2
    - invisible-watermark
    - imageio==2.9.0
    - imageio-ffmpeg==0.4.2
    - pytorch-lightning==1.8.0
    - omegaconf==2.1.1
    - test-tube>=0.7.5
    - streamlit>=0.73.1
    - einops==0.3.0
    - torch-fidelity==0.3.0
    - transformers==4.19.2
    - torchmetrics==0.7.0
    - kornia==0.6
    - prefetch_generator
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
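
With the updated environment.yaml, the environment is typically recreated with conda. The environment name is defined at the top of the file, which this hunk does not show; `ldm` is the name used by the upstream stable-diffusion repo and is an assumption here:

```
conda env create -f environment.yaml
conda activate ldm
```
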
@ -0,0 +1,184 @@
from typing import Dict
import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset


def make_multi_folder_data(paths, caption_files=None, **kwargs):
    """Make a concat dataset from multiple folders
    Doesn't support captions yet
    If paths is a list, that's ok; if it's a Dict, interpret it as:
    k=folder v=n_times to repeat that
    """
    list_of_paths = []
    if isinstance(paths, (Dict, DictConfig)):
        assert caption_files is None, \
            "Caption files not yet supported for repeats"
        for folder_path, repeats in paths.items():
            list_of_paths.extend([folder_path] * repeats)
        paths = list_of_paths

    if caption_files is not None:
        datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
    else:
        datasets = [FolderData(p, **kwargs) for p in paths]
    return torch.utils.data.ConcatDataset(datasets)


class FolderData(Dataset):

    def __init__(self,
                 root_dir,
                 caption_file=None,
                 image_transforms=[],
                 ext="jpg",
                 default_caption="",
                 postprocess=None,
                 return_paths=False,
                 ) -> None:
        """Create a dataset from a folder of images.
        If you pass in a root directory it will be searched for images
        ending in ext (ext can be a list)
        """
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        if caption_file is not None:
            with open(caption_file, "rt") as f:
                # use a dedicated name so the image-extension argument `ext` is not clobbered
                caption_ext = Path(caption_file).suffix.lower()
                if caption_ext == ".json":
                    captions = json.load(f)
                elif caption_ext == ".jsonl":
                    lines = f.readlines()
                    lines = [json.loads(x) for x in lines]
                    captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
                else:
                    raise ValueError(f"Unrecognised format: {caption_ext}")
            self.captions = captions
        else:
            self.captions = None

        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]

        # Only used if there is no caption file
        self.paths = []
        for e in ext:
            self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
        if isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        # copy before extending so the shared default argument is not mutated across instances
        image_transforms = list(image_transforms) + [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]
        self.tform = transforms.Compose(image_transforms)

    def __len__(self):
        if self.captions is not None:
            return len(self.captions.keys())
        else:
            return len(self.paths)

    def __getitem__(self, index):
        data = {}
        if self.captions is not None:
            chosen = list(self.captions.keys())[index]
            caption = self.captions.get(chosen, None)
            if caption is None:
                caption = self.default_caption
            filename = self.root_dir / chosen
        else:
            filename = self.paths[index]

        if self.return_paths:
            data["path"] = str(filename)

        im = Image.open(filename)
        im = self.process_im(im)
        data["image"] = im

        if self.captions is not None:
            data["txt"] = caption
        else:
            data["txt"] = self.default_caption

        if self.postprocess is not None:
            data = self.postprocess(data)

        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)


def hf_dataset(
    name,
    image_transforms=[],
    image_column="img",
    label_column="label",
    text_column="txt",
    split='train',
    image_key='image',
    caption_key='txt',
):
    """Make huggingface dataset with appropriate list of transforms applied"""
    ds = load_dataset(name, split=split)
    image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    tform = transforms.Compose(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}"

    def pre_process(examples):
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]

        # map CIFAR10 integer labels to text captions usable as conditioning
        label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
                              5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
        processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]]

        return processed

    ds.set_transform(pre_process)
    return ds


class TextOnly(Dataset):

    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images"""
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            repeated = [n_gpus * [x] for x in self.captions]
            self.captions = []
            [self.captions.extend(x) for x in repeated]

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        with open(filename, 'rt') as f:
            captions = f.readlines()
        return [x.strip('\n') for x in captions]
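
A quick usage sketch for the new dataset module (target `ldm.data.cifar10.hf_dataset`, as wired up in the yaml config above; the empty transform list is illustrative):

```
from ldm.data.cifar10 import hf_dataset

# build the CIFAR10 dataset with captions derived from the class labels
ds = hf_dataset(
    name="cifar10",
    image_transforms=[],  # extra torchvision transform configs could go here
    split="train",
)

sample = ds[0]
print(sample["image"].shape)  # (h, w, c) tensor scaled to [-1, 1]
print(sample["txt"])          # e.g. "frog"
```
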
@ -1,11 +1,12 @@
albumentations==0.4.3
diffusers
pudb==2019.2
datasets
invisible-watermark
imageio==2.9.0
imageio-ffmpeg==0.4.2
omegaconf==2.1.1
multiprocess
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0