diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md
index 06459bfe5..c12177c36 100644
--- a/examples/images/diffusion/README.md
+++ b/examples/images/diffusion/README.md
@@ -54,14 +54,14 @@ pip install -r requirements.txt && pip install .
 > The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.
 
 ## Dataset
-The DataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/),
+The dataset is from [LAION-5B](https://laion.ai/blog/laion-5b/), a subset of [LAION](https://laion.ai/),
 you should the change the `data.file_path` in the `config/train_colossalai.yaml`
 
 ## Training
-we provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml`, `train_ddp.yaml`
+We provide the script `train.sh` to run the training task, with the training strategy configured in `configs/train_colossalai.yaml`.
 
-for example, you can run the training from colossalai by
+For example, you can run training with ColossalAI by
 ```
 python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
 ```
@@ -69,13 +69,25 @@ python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
 - you can change the `--logdir` the save the log information and the last checkpoint
 
 ### Training config
-you can change the trainging config in the yaml file
+You can change the training config in the YAML file:
 - accelerator: acceleratortype, default 'gpu'
 - devices: device number used for training, default 4
 - max_epochs: max training epochs
 - precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai
 
+## Example
+
+### Training on CIFAR10
+
+We provide a finetuning example on the CIFAR10 dataset.
+
+You can run it with the config `train_colossalai_cifar10.yaml`:
+```
+python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
+```
+
+
 ## Comments
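The `-b`/`--base` flag points `main.py` at one of these YAML files, and every `target:`/`params:` pair inside is turned into a Python object through `ldm.util.instantiate_from_config`. Below is a minimal sketch of that pattern, not the actual `main.py` logic (which also merges CLI overrides and builds the Lightning `Trainer`); it assumes the example directory is on `PYTHONPATH` and the `datasets` package is installed.

```
# Sketch of how a `-b` config resolves to objects via instantiate_from_config.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/train_colossalai_cifar10.yaml")

# e.g. the CIFAR10 example's training set resolves to ldm.data.cifar10.hf_dataset
train_dataset = instantiate_from_config(config.data.params.train)
print(len(train_dataset))  # 50000 images in the CIFAR10 train split
```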
diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
new file mode 100644
index 000000000..63b9d1c01
--- /dev/null
+++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
@@ -0,0 +1,123 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: image
+    cond_stage_key: txt
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1.e-4 ]
+        f_min: [ 1.e-10 ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: False
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        use_fp16: True
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 4
+    num_workers: 4
+    train:
+      target: ldm.data.cifar10.hf_dataset
+      params:
+        name: cifar10
+        image_transforms:
+        - target: torchvision.transforms.Resize
+          params:
+            size: 512
+            interpolation: 3
+        - target: torchvision.transforms.RandomCrop
+          params:
+            size: 512
+        - target: torchvision.transforms.RandomHorizontalFlip
+
+lightning:
+  trainer:
+    accelerator: 'gpu'
+    devices: 2
+    log_gpu_memory: all
+    max_epochs: 2
+    precision: 16
+    auto_select_gpus: False
+    strategy:
+      target: pytorch_lightning.strategies.ColossalAIStrategy
+      params:
+        use_chunk: False
+        enable_distributed_storage: True
+        placement_policy: cuda
+        force_outputs_fp32: False
+
+    log_every_n_steps: 2
+    logger: True
+    default_root_dir: "/tmp/diff_log/"
+    profiler: pytorch
+
+  logger_config:
+    wandb:
+      target: pytorch_lightning.loggers.WandbLogger
+      params:
+        name: nowname
+        save_dir: "/tmp/diff_log/"
+        offline: opt.debug
+        id: nowname
\ No newline at end of file
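For reference, the `lightning.trainer.strategy` block in this config is what becomes a PyTorch Lightning `ColossalAIStrategy`. A rough programmatic equivalent of the trainer section, assuming the pinned Lightning version with ColossalAI support (a sketch mirroring the YAML values, not the code path `main.py` takes):

```
# Rough equivalent of the `lightning.trainer` section above; values mirror the YAML.
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(
    use_chunk=False,
    enable_distributed_storage=True,
    placement_policy="cuda",
    force_outputs_fp32=False,
)
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    max_epochs=2,
    precision=16,
    strategy=strategy,
)
```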
diff --git a/examples/images/diffusion/environment.yaml b/examples/images/diffusion/environment.yaml
index 79b706b83..59baa3c76 100644
--- a/examples/images/diffusion/environment.yaml
+++ b/examples/images/diffusion/environment.yaml
@@ -11,6 +11,7 @@ dependencies:
   - numpy=1.19.2
   - pip:
     - albumentations==0.4.3
+    - datasets
     - diffusers
     - opencv-python==4.6.0.66
     - pudb==2019.2
diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py
new file mode 100644
index 000000000..53cd61263
--- /dev/null
+++ b/examples/images/diffusion/ldm/data/cifar10.py
@@ -0,0 +1,184 @@
+from typing import Dict
+import numpy as np
+from omegaconf import DictConfig, ListConfig
+import torch
+from torch.utils.data import Dataset
+from pathlib import Path
+import json
+from PIL import Image
+from torchvision import transforms
+from einops import rearrange
+from ldm.util import instantiate_from_config
+from datasets import load_dataset
+
+def make_multi_folder_data(paths, caption_files=None, **kwargs):
+    """Make a concat dataset from multiple folders
+    Don't support captions yet
+    If paths is a list, that's ok, if it's a Dict interpret it as:
+    k=folder v=n_times to repeat that
+    """
+    list_of_paths = []
+    if isinstance(paths, (Dict, DictConfig)):
+        assert caption_files is None, \
+            "Caption files not yet supported for repeats"
+        for folder_path, repeats in paths.items():
+            list_of_paths.extend([folder_path]*repeats)
+        paths = list_of_paths
+
+    if caption_files is not None:
+        datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
+    else:
+        datasets = [FolderData(p, **kwargs) for p in paths]
+    return torch.utils.data.ConcatDataset(datasets)
+
+class FolderData(Dataset):
+    def __init__(self,
+        root_dir,
+        caption_file=None,
+        image_transforms=[],
+        ext="jpg",
+        default_caption="",
+        postprocess=None,
+        return_paths=False,
+        ) -> None:
+        """Create a dataset from a folder of images.
+        If you pass in a root directory it will be searched for images
+        ending in ext (ext can be a list)
+        """
+        self.root_dir = Path(root_dir)
+        self.default_caption = default_caption
+        self.return_paths = return_paths
+        if isinstance(postprocess, DictConfig):
+            postprocess = instantiate_from_config(postprocess)
+        self.postprocess = postprocess
+        if caption_file is not None:
+            with open(caption_file, "rt") as f:
+                ext = Path(caption_file).suffix.lower()
+                if ext == ".json":
+                    captions = json.load(f)
+                elif ext == ".jsonl":
+                    lines = f.readlines()
+                    lines = [json.loads(x) for x in lines]
+                    captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
+                else:
+                    raise ValueError(f"Unrecognised format: {ext}")
+            self.captions = captions
+        else:
+            self.captions = None
+
+        if not isinstance(ext, (tuple, list, ListConfig)):
+            ext = [ext]
+
+        # Only used if there is no caption file
+        self.paths = []
+        for e in ext:
+            self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
+        if isinstance(image_transforms, ListConfig):
+            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
+        image_transforms.extend([transforms.ToTensor(),
+                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms = transforms.Compose(image_transforms)
+        self.tform = image_transforms
+
+
+    def __len__(self):
+        if self.captions is not None:
+            return len(self.captions.keys())
+        else:
+            return len(self.paths)
+
+    def __getitem__(self, index):
+        data = {}
+        if self.captions is not None:
+            chosen = list(self.captions.keys())[index]
+            caption = self.captions.get(chosen, None)
+            if caption is None:
+                caption = self.default_caption
+            filename = self.root_dir/chosen
+        else:
+            filename = self.paths[index]
+
+        if self.return_paths:
+            data["path"] = str(filename)
+
+        im = Image.open(filename)
+        im = self.process_im(im)
+        data["image"] = im
+
+        if self.captions is not None:
+            data["txt"] = caption
+        else:
+            data["txt"] = self.default_caption
+
+        if self.postprocess is not None:
+            data = self.postprocess(data)
+
+        return data
+
+    def process_im(self, im):
+        im = im.convert("RGB")
+        return self.tform(im)
+
+def hf_dataset(
+    name,
+    image_transforms=[],
+    image_column="img",
+    label_column="label",
+    text_column="txt",
+    split='train',
+    image_key='image',
+    caption_key='txt',
+    ):
+    """Make huggingface dataset with appropriate list of transforms applied
+    """
+    ds = load_dataset(name, split=split)
+    image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
+    image_transforms.extend([transforms.ToTensor(),
+                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+    tform = transforms.Compose(image_transforms)
+
+    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
+    assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}"
+
+    def pre_process(examples):
+        processed = {}
+        processed[image_key] = [tform(im) for im in examples[image_column]]
+
+        label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
+
+        processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]]
+
+        return processed
+
+    ds.set_transform(pre_process)
+    return ds
+
+class TextOnly(Dataset):
+    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
+        """Returns only captions with dummy images"""
+        self.output_size = output_size
+        self.image_key = image_key
+        self.caption_key = caption_key
+        if isinstance(captions, Path):
+            self.captions = self._load_caption_file(captions)
+        else:
+            self.captions = captions
+
+        if n_gpus > 1:
+            # hack to make sure that all the captions appear on each gpu
+            repeated = [n_gpus*[x] for x in self.captions]
+            self.captions = []
+            [self.captions.extend(x) for x in repeated]
+
+    def __len__(self):
+        return len(self.captions)
+
+    def __getitem__(self, index):
+        dummy_im = torch.zeros(3, self.output_size, self.output_size)
+        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
+        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
+
+    def _load_caption_file(self, filename):
+        with open(filename, 'rt') as f:
+            captions = f.readlines()
+        return [x.strip('\n') for x in captions]
\ No newline at end of file
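The `hf_dataset` helper above yields samples with an `image` tensor in `h w c` layout scaled to `[-1, 1]` and a `txt` caption derived from the CIFAR10 label. A quick standalone check (illustrative only; training goes through `main.DataModuleFromConfig` instead), where transforms are passed as `target`/`params` dicts because the helper runs them through `instantiate_from_config`:

```
from ldm.data.cifar10 import hf_dataset

# Transforms are given in instantiate_from_config style, mirroring the YAML config.
tforms = [
    {"target": "torchvision.transforms.Resize", "params": {"size": 512, "interpolation": 3}},
    {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
    {"target": "torchvision.transforms.RandomHorizontalFlip"},
]
ds = hf_dataset("cifar10", image_transforms=tforms)
sample = ds[0]
print(sample["image"].shape, sample["txt"])  # torch.Size([512, 512, 3]) and a label such as 'frog'
```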
diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt
index f5c9ee70a..54bc00029 100644
--- a/examples/images/diffusion/requirements.txt
+++ b/examples/images/diffusion/requirements.txt
@@ -1,11 +1,12 @@
 albumentations==0.4.3
 diffusers
-opencv-python==4.1.2.30
 pudb==2019.2
+datasets
 invisible-watermark
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
 omegaconf==2.1.1
+multiprocess
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0