diff --git a/applications/stable-diffusion/text_img2img/README.md b/applications/stable-diffusion/text_img2img/README.md new file mode 100644 index 000000000..76e4fd386 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/README.md @@ -0,0 +1,220 @@ +# Text-to-Image/Image-to-Image: Stable Diffusion with Colossal-AI + +Acceleration of AIGC (AI-Generated Content) models such as [Text-to-image model](https://huggingface.co/CompVis/stable-diffusion-v1-4) and Instruct-Pix2Pix such as [Image-to-Image Model](https://huggingface.co/docs/diffusers/training/instructpix2pix) + +More details can be found in our [blog of Stable Diffusion v1](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) and [blog of Stable Diffusion v2](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0). + + +## Installation + +### Option #1: Install from source +#### Step 1: Requirements + +To begin with, make sure your operating system has the cuda version suitable for this exciting training session, which is cuda11.6/11.8. For your convience, we have set up the rest of packages here. You can create and activate a suitable [conda](https://conda.io/) environment named `ldm` : + +``` +conda env create -f environment.yaml +conda activate ldm +``` + +You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running: + +``` +conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch +pip install transformers diffusers invisible-watermark +``` + +#### Step 2: Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website + +You can install the latest version (0.2.7) from our official website or from source. Notice that the suitable version for this training is colossalai(0.2.5), which stands for torch(1.12.1). + +##### Download suggested version for this training + +``` +pip install colossalai==0.2.5 +``` + +##### Download the latest version from pip for latest torch version + +``` +pip install colossalai +``` + +##### From source: + +``` +git clone https://github.com/hpcaitech/ColossalAI.git +cd ColossalAI + +# install colossalai +CUDA_EXT=1 pip install . +``` + +#### Step 3: Accelerate with flash attention by xformers (Optional) + +Notice that xformers will accelerate the training process at the cost of extra disk space. The suitable version of xformers for this training process is 0.0.12, which can be downloaded directly via pip. For more release versions, feel free to check its official website: [XFormers](https://pypi.org/project/xformers/) + +``` +pip install xformers==0.0.12 +``` + + +### stable-diffusion-model (Recommended) + +For example: You can follow this [link] (https://huggingface.co/CompVis/stable-diffusion-v1-4) to download your model. In our training example, we choose Stable-Diffusion-v1-4 as a demo example to who how to train our model. + +### stable-diffusion-v1-4 + +``` +git lfs install +git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 +``` + + +## Dataset + +The dataSet is from [Dataset-HuggingFace](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads). In our examples, we choose lambdalabs/pokemon-blip-captions as a demo example for text-to-image, and fusing/instructpix2pix-1000-samples as our data to train instruct-pix2pix model. You can also create your own dataset, but make sure your data set matches with formats in this [website] (https://huggingface.co/docs/diffusers/training/create_dataset). + +## Training + +We provide the script `trainer_no_colossalai_text_to_image.sh`, `trainer_no_colossalai_image_to_image.sh` and `trainer_no_colossalai_dreambooth.sh` to run the training task without colossalai. Meanwhile, we also provided script called `trainer_with_colossalai_text_to_image.sh` to train text-to-image, and `trainer_with_colossalai_image_to_image.sh` to train instruct-pix2pix models using colossalai. We also provided `trainer_with_colossalai_dreambooth.sh` to train dreambooth model using colossalai. The following is a command demo: +``` +#!/bin/bash + +# export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_ID="fusing/instructpix2pix-1000-samples" + +torchrun --nproc_per_node 4 stable_diffusion_colossalai_trainer.py \ + --mixed_precision="fp16" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=200 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --seed=42 \ + --plugin="gemini" \ + --placement="cuda" \ + --task_type="image_to_image" \ + --output_dir="instruct_pix2pix" + +``` +Also, if you want to LoRA to fine-tune your model, you can use the following command line to yse LoRA to fine-tune your model. You only need to add --use_lora in your arguments. If you are not familar with LoRA, you can check this [website](https://huggingface.co/docs/diffusers/training/lora). There is a simple example below to demonstarte how to launch training. + +``` +#!/bin/bash + +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +torchrun --nproc_per_node 4 stable_diffusion_colossalai_trainer.py \ + --mixed_precision="fp16" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=100 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" \ + --plugin="gemini" \ + --placement="cuda" \ + --task_type="text_to_image" \ + --use_lora + +``` + +Also, if you want to train your dreambooth model, make sure you have correct dataset to be prepared. In our case, you need to firstly to run download_dataset_dreambooth.py to download data, then use the following command line to run training script: + +``` +torchrun --nproc_per_node 4 --standalone stable_diffusion_colossalai_trainer.py \ + --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ + --instance_data_dir="/home/lclcq/ColossalAI/applications/stable-diffusion/text_img2img/dog" \ + --output_dir="./weight_output" \ + --instance_prompt="a picture of a dog" \ + --resolution=512 \ + --plugin="gemini" \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --placement="cuda" \ + --task_type="dreambooth" +``` + + +### Training config + +You can change the training config in the yaml file + +- nproc_per_node: device number used for training, default = 4 +- precision: the precision type used in training, default = 16 (fp16), you must use fp16 if you want to apply colossalai +- plugin:we support the following training stategies: 'torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', and we choose gemini as our demonstration in our example. +- placement_policy: the training strategy supported by Colossal AI, default = 'cuda', which refers to loading all the parameters into cuda memory. On the other hand, 'cpu' refers to 'cpu offload' strategy while 'auto' enables 'Gemini', both featured by Colossal AI. +- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai) +- Also, for more arguments info, please check parse_arguments.py file in the current directory. + +### Inference config +After training, you can use the following command line to test your inference result: +``` +python text_to_image_colossalai.py --validation_prompts "a person is walking on the Moon" --saved_unet_path /path/to/unet_trained_model.bin +``` +The following is an example after running command line above, and the picture was generated after training diffusion models using our script stable_diffusion_colossalai_trainer.py. + + +## Invitation to open-source contribution +Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models! + +You may contact us or participate in the following ways: +1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks! +2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md). +3. Join the Colossal-AI community on +[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w), +and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas. +4. Send your official proposal to email contact@hpcaitech.com + +Thanks so much to all of our amazing contributors! + +## Comments + +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) and [hugging face diffusion](https://github.com/huggingface/diffusers/blob/main/examples). +, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch), +[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion). +Thanks for open-sourcing! + +- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories). + +- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch). + +## BibTeX + +``` +@article{bian2021colossal, + title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training}, + author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang}, + journal={arXiv preprint arXiv:2110.14883}, + year={2021} +} +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +@article{dao2022flashattention, + title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} +``` diff --git a/applications/stable-diffusion/text_img2img/download_dataset_dreambooth.py b/applications/stable-diffusion/text_img2img/download_dataset_dreambooth.py new file mode 100644 index 000000000..a79c40781 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/download_dataset_dreambooth.py @@ -0,0 +1,9 @@ +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) + diff --git a/applications/stable-diffusion/text_img2img/dreambooth_utils.py b/applications/stable-diffusion/text_img2img/dreambooth_utils.py new file mode 100644 index 000000000..ccdc08263 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/dreambooth_utils.py @@ -0,0 +1,110 @@ +import PIL +from PIL import Image +from pathlib import Path +from typing import Optional + +from torch.utils.data import Dataset +from torchvision import transforms + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/inference.py b/applications/stable-diffusion/text_img2img/inference.py new file mode 100644 index 000000000..02f658ab4 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/inference.py @@ -0,0 +1,124 @@ +import argparse +import PIL +import requests + +import torch +from torch import nn +from diffusers import DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL, StableDiffusionInstructPix2PixPipeline +from transformers import CLIPTextModel, CLIPTokenizer +from colossalai.booster import Booster + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + +def main(): + parser = argparse.ArgumentParser(description="Simple example of a inference script.") + parser.add_argument( + "--model_id", + type=str, + default=None, + help=("your trained model id or model name") + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument("--unet_saved_path", type=str, default=None, help=("path of your trained unet model")) + parser.add_argument( + "--val_image_url", + type=str, + default = None, + help=("the url of your test image") + ) + parser.add_argument('--task_type', + type=str, + default='text_to_image', + choices=['text_to_image', 'image_to_image'], + help="plugin to use") + + args = parser.parse_args() + + model_id = args.model_id + assert args.validation_prompts is not None, "have to provide a prompt for this inference file" + if args.task_type == "image_to_image": + assert args.val_image_url is not None, "the image url has to be provided" + + tokenizer = CLIPTokenizer.from_pretrained( + model_id, subfolder="tokenizer", revision=None + ) + + text_encoder = CLIPTextModel.from_pretrained( + model_id, subfolder="text_encoder", revision=None + ) + vae = AutoencoderKL.from_pretrained( + model_id, subfolder="vae", revision=None + ) + + unet = UNet2DConditionModel.from_pretrained( + model_id, subfolder="unet", revision=None + ) + + if args.task_type == "image_to_image": + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in + + if args.unet_saved_path is not None: + print("loading trained model from {}".format(args.unet_saved_path)) + unet.load_state_dict(torch.load(args.unet_saved_path)) + + if args.task_type == "text_to_image": + print("use text_to_image pipeline") + pipeline = StableDiffusionPipeline.from_pretrained( + model_id, + text_encoder=text_encoder, + vae=vae, + unet=unet, + ) + else: + print("use image_to_image pipeline") + pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( + model_id, + text_encoder=text_encoder, + vae=vae, + unet=unet, + ) + + + pipe = pipeline.to("cuda") + pipe.enable_attention_slicing() + prompt = args.validation_prompts + if args.task_type == "text_to_image": + print("get result from text_to_image model ...") + image = pipe(prompt).images[0] + image.save("text_to_image_example.png") + else: + print("get result from image_to_image model ...") + original_image = download_image(args.val_image_url) + original_image.save("original_image.png") + image = pipeline(prompt, image=original_image, num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + ).images[0] + + image.save("image_to_image_example.png") + + + +if __name__ == "__main__": + main() + diff --git a/applications/stable-diffusion/text_img2img/parse_arguments.py b/applications/stable-diffusion/text_img2img/parse_arguments.py new file mode 100644 index 000000000..d25396862 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/parse_arguments.py @@ -0,0 +1,401 @@ +import os +import argparse +import logging + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--externel_unet_path", + type=str, + default=None, + required=False, + help="Path to the externel unet model.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help="A folder containing the training data of instance images.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--instance_prompt", + type=str, + default="a photo of sks dog", + required=False, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--edit_prompt_column", + type=str, + default="edit_prompt", + help="The column of the dataset containing the edit instruction.", + ) + parser.add_argument( + "--edited_image_column", + type=str, + default="edited_image", + help="The column of the dataset containing the edited image.", + ) + parser.add_argument( + "--original_image_column", + type=str, + default="input_image", + help="The column of the dataset containing the original image on which edits where made.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--use_lora", + default=False, + action="store_true", + help=( + "Whether to use LoRA to fine-tune your model" + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--conditioning_dropout_prob", + type=float, + default=None, + help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.", + ) + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + help="plugin to use") + parser.add_argument( + "--placement", + type=str, + default="cpu", + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument('--task_type', + type=str, + default='text_to_image', + choices=['text_to_image', 'image_to_image', 'dreambooth'], + help="plugin to use") + + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.task_type == "dreambooth": + if args.instance_data_dir is None: + raise ValueError("need to provide instance_data_dir for dreambooth training task") + elif args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + diff --git a/applications/stable-diffusion/text_img2img/scripts_run/inference.sh b/applications/stable-diffusion/text_img2img/scripts_run/inference.sh new file mode 100644 index 000000000..9c959d99b --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/inference.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +python inference.py --validation_prompts "comparison between beijing and shanghai" \ + --unet_saved_path ./sd-pokemon-model/diffusion_pytorch_model.bin \ + --model_id "CompVis/stable-diffusion-v1-4" \ + --task_type "text_to_image" \ + +python inference.py --validation_prompts "turn the color of the mountain into yellow" \ + --val_image_url https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png \ + --unet_saved_path ./instruct_pix2pix/diffusion_pytorch_model.bin \ + --model_id "CompVis/stable-diffusion-v1-4" \ + --task_type "image_to_image" \ \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_dreambooth.sh b/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_dreambooth.sh new file mode 100644 index 000000000..07cdeed67 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_dreambooth.sh @@ -0,0 +1,19 @@ +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 + +torchrun --nproc_per_node 4 --standalone stable_diffusion_colossalai_trainer.py \ + --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ + --instance_data_dir="/home/lclcq/ColossalAI/applications/stable-diffusion/text_img2img/dog" \ + --output_dir="./weight_output" \ + --instance_prompt="a picture of a dog" \ + --resolution=512 \ + --plugin="gemini" \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --placement="cuda" \ + --task_type="dreambooth" + \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_image_to_image.sh b/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_image_to_image.sh new file mode 100644 index 000000000..802ff5928 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_image_to_image.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_ID="fusing/instructpix2pix-1000-samples" + +torchrun --nproc_per_node 4 stable_diffusion_colossalai_trainer.py \ + --mixed_precision="fp16" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=12000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --seed=42 \ + --plugin="gemini" \ + --placement="cuda" \ + --task_type="image_to_image" \ + --output_dir="instruct_pix2pix" \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_text_to_image.sh b/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_text_to_image.sh new file mode 100644 index 000000000..4dfb4359d --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/trainer_colossalai_text_to_image.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +torchrun --nproc_per_node 4 stable_diffusion_colossalai_trainer.py \ + --mixed_precision="fp16" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=12000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" \ + --plugin="gemini" \ + --placement="cuda" \ + --task_type="text_to_image" diff --git a/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_dreambooth.sh b/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_dreambooth.sh new file mode 100644 index 000000000..436bc57d2 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_dreambooth.sh @@ -0,0 +1,28 @@ +# python train_dreambooth.py \ +# --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ +# --instance_data_dir="/home/lclcq/ColossalAI/applications/stable-diffusion/text_img2img/dog" \ +# --output_dir="./weight_output" \ +# --instance_prompt="a photo of a dog" \ +# --resolution=512 \ +# --train_batch_size=1 \ +# --gradient_accumulation_steps=1 \ +# --learning_rate=5e-6 \ +# --lr_scheduler="constant" \ +# --lr_warmup_steps=0 \ +# --num_class_images=200 \ + + +accelerate launch --mixed_precision="fp16" stable_diffusion_trainer.py \ + --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ + --instance_data_dir="/home/lclcq/ColossalAI/applications/stable-diffusion/text_img2img/dog" \ + --output_dir="./weight_output" \ + --instance_prompt="a picture of a dog" \ + --resolution=512 \ + --plugin="gemini" \ + --train_batch_size=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --placement="cuda" \ + --task_type="dreambooth" \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_image_to_image.sh b/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_image_to_image.sh new file mode 100644 index 000000000..5171223e2 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_image_to_image.sh @@ -0,0 +1,21 @@ +#!/bin/bash + + +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export DATASET_ID="fusing/instructpix2pix-1000-samples" + +accelerate launch --mixed_precision="fp16" stable_diffusion_trainer.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --resolution=256 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=12000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --mixed_precision=fp16 \ + --seed=42 \ + --plugin="gemini" \ + --placement="cuda" \ + --task_type="image_to_image" \ + --output_dir="instruct_pix2pix" \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_text_to_image.sh b/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_text_to_image.sh new file mode 100644 index 000000000..7bc4a5794 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/scripts_run/trainer_no_colossalai_text_to_image.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" stable_diffusion_trainer.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" \ + --task_type="text_to_image" \ No newline at end of file diff --git a/applications/stable-diffusion/text_img2img/stable_diffusion_colossalai_trainer.py b/applications/stable-diffusion/text_img2img/stable_diffusion_colossalai_trainer.py new file mode 100644 index 000000000..139acefa0 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/stable_diffusion_colossalai_trainer.py @@ -0,0 +1,796 @@ +import argparse +import logging +import math +import os +import PIL +import random +import shutil +from pathlib import Path +from typing import Optional + +import accelerate +import datasets +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import AutoTokenizer, PretrainedConfig +from transformers.utils import ContextManagers +from parse_arguments import parse_args +from dreambooth_utils import DreamBoothDataset, PromptDataset, get_full_repo_name + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.cross_attention import LoRACrossAttnProcessor +from diffusers.loaders import AttnProcsLayers + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini import get_static_torch_model +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin + + + +if is_wandb_available(): + import wandb + +disable_existing_loggers() +logger = get_dist_logger() + + +def convert_to_np(image, resolution): + image = image.convert("RGB").resize((resolution, resolution)) + return np.array(image).transpose(2, 0, 1) + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def load_tokenizer(config): + args = config + if config.task_type == "dreambooth": + # Load the tokenizer + if args.tokenizer_name: + logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + logger.info("Loading tokenizer from pretrained model", ranks=[0]) + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + else: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + return tokenizer + +def load_text_endcoder(args): + # import correct text encoder class + if args.task_type == "dreambooth": + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load models and create wrapper for stable diffusion + + logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0]) + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + else: + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + + return text_encoder + +def main(): + args = parse_args() + + if args.task_type == "dreambooth": + assert args.instance_data_dir is not None, "instance_data_dir has to be provided for dreambooth training case" + + DATASET_NAME_MAPPING = {} + WANDB_TABLE_COL_NAMES = [] + if args.task_type == "image_to_image": + DATASET_NAME_MAPPING = { + "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"), + } + WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"] + else: + DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), + } + + + if args.seed is None: + colossalai.launch_from_torch(config={}) + else: + colossalai.launch_from_torch(config={}, seed=args.seed) + + coordinator = DistCoordinator() + world_size = coordinator.world_size + + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + logging_dir = os.path.join(args.output_dir, args.logging_dir) + # Make one log on every process with the configuration for debugging. + + if args.seed is not None: + generator = torch.Generator(device=get_current_device()).manual_seed(args.seed) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + local_rank = coordinator.local_rank + local_rank = int(local_rank) + logger.info(f'local_rank: {local_rank}') + + if local_rank == 0: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + + # Handle the repository creation + if local_rank == 0: + if args.output_dir is not None: + logger.info(f"create output dir : {args.output_dir}") + os.makedirs(args.output_dir, exist_ok=True) + + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + # Load the tokenizer + tokenizer = load_tokenizer(args) + #load text_encoder + text_encoder = load_text_endcoder(args) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + + if args.externel_unet_path is None: + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.non_ema_revision) + else: + logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) + unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, + revision=args.revision, + low_cpu_mem_usage=False) + + if args.task_type == "image_to_image": + # InstructPix2Pix uses an additional image for conditioning. To accommodate that, + # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is + # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized + # from the pre-trained checkpoints. For the extra channels added to the first layer, they are + # initialized to zero. + logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + if args.use_lora: + unet.requires_grad_(False) + + # Set correct lora layers + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim) + + unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(unet.attn_processors) + + + # Create EMA for the unet. + if args.use_ema: + ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size + ) + + + + # config optimizer for colossalai zero + optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.task_type != "dreambooth": + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.task_type == "text_to_image": + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + else: + if args.original_image_column is None: + original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + original_image_column = args.original_image_column + if original_image_column not in column_names: + raise ValueError( + f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.edit_prompt_column is None: + edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + edit_prompt_column = args.edit_prompt_column + if edit_prompt_column not in column_names: + raise ValueError( + f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.edited_image_column is None: + edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2] + else: + edited_image_column = args.edited_image_column + if edited_image_column not in column_names: + raise ValueError( + f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" + ) + + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions_dispatch(task_type): + def tokenize_captions_1(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + def tokenize_captions_2(captions): + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + if task_type == "text_to_image": + return tokenize_captions_1 + return tokenize_captions_2 + + + # Preprocessing the datasets. + if args.task_type == "text_to_image": + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + else: + train_transforms = transforms.Compose( + [ + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + ] + ) + + tokenize_captions = tokenize_captions_dispatch(args.task_type) + def process_train_dispatch(task_type): + def preprocess_train_text_to_image(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + def preprocess_images(examples): + original_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[original_image_column]] + ) + edited_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[edited_image_column]] + ) + # We need to ensure that the original and the edited images undergo the same + # augmentation transforms. + images = np.concatenate([original_images, edited_images]) + images = torch.tensor(images) + images = 2 * (images / 255) - 1 + return train_transforms(images) + + def preprocess_train_image_to_image(examples): + # Preprocess images. + preprocessed_images = preprocess_images(examples) + # Since the original and edited images were concatenated before + # applying the transformations, we need to separate them and reshape + # them accordingly. + original_images, edited_images = preprocessed_images.chunk(2) + original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) + edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) + + # Collate the preprocessed images into the `examples`. + examples["original_pixel_values"] = original_images + examples["edited_pixel_values"] = edited_images + + # Preprocess the captions. + captions = list(examples[edit_prompt_column]) + examples["input_ids"] = tokenize_captions(captions) + return examples + + if task_type == "text_to_image": + return preprocess_train_text_to_image + return preprocess_train_image_to_image + + + + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + if args.task_type == "dreambooth": + # prepare dataset + logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + else: + + preprocess_train = process_train_dispatch(args.task_type) + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn_dispatch(task_type): + def collate_fn_text_to_image(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} + + def collate_fn_image_to_image(examples): + original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples]) + original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float() + edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) + edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return { + "original_pixel_values": original_pixel_values, + "edited_pixel_values": edited_pixel_values, + "input_ids": input_ids, + } + + def collate_fn_dreambooth(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + if task_type == "text_to_image": + return collate_fn_text_to_image + elif task_type == "dreambooth": + return collate_fn_dreambooth + else: + return collate_fn_image_to_image + + # DataLoaders creation: + collate_fn = collate_fn_dispatch(args.task_type) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + if args.use_ema: + ema_unet.to(get_current_device()) + + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + text_encoder.to(get_current_device(), dtype=weight_dtype) + vae.to(get_current_device(), dtype=weight_dtype) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable = (local_rank!=0)) + progress_bar.set_description("Steps") + + # Use Booster API to use Gemini/Zero with ColossalAI + unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler) + + torch.cuda.synchronize() + print("start training ... ") + + save_flag = False + for epoch in range(args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + for key, value in batch.items(): + batch[key] = value.to(get_current_device(), non_blocking=True) + + # Convert images to latent space + optimizer.zero_grad() + + if args.task_type == "text_to_image" or args.task_type == "dreambooth": + latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() + else: + latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like(noise) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.input_perturbation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + if args.task_type == "image_to_image": + # Get the additional image embedding for conditioning. + # Instead of getting a diagonal Gaussian here, we simply take the mode. + original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode() + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if args.conditioning_dropout_prob is not None: + if args.seed is not None: + random_p = torch.rand(bsz, device=latents.device, generator=generator) + else: + random_p = torch.rand(bsz, device=latents.device) + + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * args.conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = text_encoder(tokenize_captions([""]).to(get_current_device()))[0] + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original images. + image_mask_dtype = original_image_embeds.dtype + image_mask = 1 - ( + (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype) + * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + original_image_embeds = image_mask * original_image_embeds + + # Concatenate the `original_image_embeds` with the `noisy_latents`. + concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1) + noisy_latents = concatenated_noisy_latents + + + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + + + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + global_step += 1 + progress_bar.update(1) + + + if global_step % args.checkpointing_steps == 0: + + if local_rank == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + + logger.info(f'train_loss : {loss.detach().item()} for global_step : {global_step}') + logger.info(f'lr: {lr_scheduler.get_last_lr()[0]}') + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + if global_step >= args.max_train_steps: + break + + torch.cuda.synchronize() + + + torch.cuda.synchronize() + booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin")) + logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}") + + +if __name__ == "__main__": + main() diff --git a/applications/stable-diffusion/text_img2img/stable_diffusion_trainer.py b/applications/stable-diffusion/text_img2img/stable_diffusion_trainer.py new file mode 100644 index 000000000..445e7da56 --- /dev/null +++ b/applications/stable-diffusion/text_img2img/stable_diffusion_trainer.py @@ -0,0 +1,844 @@ +import argparse +import logging +import math +import os +import PIL +import random +import shutil +from pathlib import Path +from typing import Optional +from dreambooth_utils import DreamBoothDataset, PromptDataset, get_full_repo_name + +import accelerate +import datasets +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import AutoTokenizer, PretrainedConfig +from transformers.utils import ContextManagers +from parse_arguments import parse_args + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +if is_wandb_available(): + import wandb + + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + +def convert_to_np(image, resolution): + image = image.convert("RGB").resize((resolution, resolution)) + return np.array(image).transpose(2, 0, 1) + + +def download_image(url): + image = PIL.Image.open(requests.get(url, stream=True).raw) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path, revision): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + +def load_tokenizer(config): + args = config + if config.task_type == "dreambooth": + # Load the tokenizer + if args.tokenizer_name: + logger.info(f"Loading tokenizer from {args.tokenizer_name}") + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + revision=args.revision, + use_fast=False, + ) + elif args.pretrained_model_name_or_path: + logger.info("Loading tokenizer from pretrained model") + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + else: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + return tokenizer + +def load_text_endcoder(args): + # import correct text encoder class + if args.task_type == "dreambooth": + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load models and create wrapper for stable diffusion + + logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}") + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + else: + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + return text_encoder + + +def main(): + args = parse_args() + + if args.task_type == "dreambooth": + assert args.instance_data_dir is not None, "instance_data_dir has to be provided for dreambooth training case" + + DATASET_NAME_MAPPING = {} + WANDB_TABLE_COL_NAMES = [] + if args.task_type == "image_to_image": + DATASET_NAME_MAPPING = { + "fusing/instructpix2pix-1000-samples": ("input_image", "edit_prompt", "edited_image"), + } + WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"] + else: + DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), + } + + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + # Load the tokenizer + tokenizer = load_tokenizer(args) + #load text_encoder + text_encoder = load_text_endcoder(args) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + + if args.externel_unet_path is None: + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.non_ema_revision) + else: + logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}") + unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, + revision=args.revision, + low_cpu_mem_usage=False) + + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + + if args.externel_unet_path is None: + logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}") + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.non_ema_revision) + else: + logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}") + unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, + revision=args.revision, + low_cpu_mem_usage=False) + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + if args.task_type == "image_to_image": + # InstructPix2Pix uses an additional image for conditioning. To accommodate that, + # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is + # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized + # from the pre-trained checkpoints. For the extra channels added to the first layer, they are + # initialized to zero. + logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.task_type != "dreambooth": + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.task_type == "text_to_image": + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + else: + if args.original_image_column is None: + original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + original_image_column = args.original_image_column + if original_image_column not in column_names: + raise ValueError( + f"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edit_prompt_column is None: + edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + edit_prompt_column = args.edit_prompt_column + if edit_prompt_column not in column_names: + raise ValueError( + f"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.edited_image_column is None: + edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2] + else: + edited_image_column = args.edited_image_column + if edited_image_column not in column_names: + raise ValueError( + f"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions_dispatch(task_type): + def tokenize_captions_1(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + def tokenize_captions_2(captions): + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + if task_type == "text_to_image": + return tokenize_captions_1 + return tokenize_captions_2 + + # Preprocessing the datasets. + if args.task_type == "text_to_image": + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + else: + train_transforms = transforms.Compose( + [ + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + ] + ) + + + tokenize_captions = tokenize_captions_dispatch(args.task_type) + def process_train_dispatch(task_type): + def preprocess_train_text_to_image(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + def preprocess_images(examples): + original_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[original_image_column]] + ) + edited_images = np.concatenate( + [convert_to_np(image, args.resolution) for image in examples[edited_image_column]] + ) + # We need to ensure that the original and the edited images undergo the same + # augmentation transforms. + images = np.concatenate([original_images, edited_images]) + images = torch.tensor(images) + images = 2 * (images / 255) - 1 + return train_transforms(images) + + def preprocess_train_image_to_image(examples): + # Preprocess images. + preprocessed_images = preprocess_images(examples) + # Since the original and edited images were concatenated before + # applying the transformations, we need to separate them and reshape + # them accordingly. + original_images, edited_images = preprocessed_images.chunk(2) + original_images = original_images.reshape(-1, 3, args.resolution, args.resolution) + edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution) + + # Collate the preprocessed images into the `examples`. + examples["original_pixel_values"] = original_images + examples["edited_pixel_values"] = edited_images + + # Preprocess the captions. + captions = list(examples[edit_prompt_column]) + examples["input_ids"] = tokenize_captions(captions) + return examples + + if task_type == "text_to_image": + return preprocess_train_text_to_image + return preprocess_train_image_to_image + + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + + # Set the training transforms + if args.task_type == "dreambooth": + # prepare dataset + logger.info(f"Prepare dataset from {args.instance_data_dir}") + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + else: + preprocess_train = process_train_dispatch(args.task_type) + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn_dispatch(task_type): + def collate_fn_text_to_image(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} + + def collate_fn_image_to_image(examples): + original_pixel_values = torch.stack([example["original_pixel_values"] for example in examples]) + original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float() + edited_pixel_values = torch.stack([example["edited_pixel_values"] for example in examples]) + edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return { + "original_pixel_values": original_pixel_values, + "edited_pixel_values": edited_pixel_values, + "input_ids": input_ids, + } + + def collate_fn_dreambooth(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + { + "input_ids": input_ids + }, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + if task_type == "text_to_image": + return collate_fn_text_to_image + elif task_type == "dreambooth": + return collate_fn_dreambooth + else: + return collate_fn_image_to_image + + # DataLoaders creation: + collate_fn = collate_fn_dispatch(args.task_type) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + if args.task_type == "text_to_image" or args.task_type == "dreambooth": + latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() + else: + latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like(noise) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.input_perturbation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + if args.task_type == "image_to_image": + # Get the additional image embedding for conditioning. + # Instead of getting a diagonal Gaussian here, we simply take the mode. + original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode() + + # Conditioning dropout to support classifier-free guidance during inference. For more details + # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. + if args.conditioning_dropout_prob is not None: + random_p = torch.rand(bsz, device=latents.device, generator=generator) + # Sample masks for the edit prompts. + prompt_mask = random_p < 2 * args.conditioning_dropout_prob + prompt_mask = prompt_mask.reshape(bsz, 1, 1) + # Final text conditioning. + null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0] + encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states) + + # Sample masks for the original images. + image_mask_dtype = original_image_embeds.dtype + image_mask = 1 - ( + (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype) + * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype) + ) + image_mask = image_mask.reshape(bsz, 1, 1, 1) + # Final image conditioning. + original_image_embeds = image_mask * original_image_embeds + + # Concatenate the `original_image_embeds` with the `noisy_latents`. + concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1) + noisy_latents = concatenated_noisy_latents + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + main() \ No newline at end of file