import argparse
import hashlib
import math
import os
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

disable_existing_loggers()
logger = get_dist_logger()


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: Optional[str] = None):
    """Return the text-encoder class named in the pretrained model's ``text_encoder`` config.

    Args:
        pretrained_model_name_or_path: Hub id or local path of the pretrained pipeline.
        revision: Optional revision of the model to read the config from.

    Raises:
        ValueError: if the first architecture listed in the config is not supported.
    """
    # `revision` is an explicit argument (rather than a lookup of a module-level
    # `args` global) so this helper also works when `main` is driven from another module.
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
    """Parse the command-line options for DreamBooth training.

    Args:
        input_args: Optional list of argument strings; ``sys.argv`` is used when None.

    Returns:
        The parsed ``argparse.Namespace``. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that is set.

    Raises:
        ValueError: if prior preservation is requested without a class data dir or
            class prompt, or if ``--gradient_accumulation_steps`` is below 1.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default="a photo of sks dog",
        required=False,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
              " class_data_dir, additional images will be sampled with class_prompt."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
              " resolution"),
    )
    parser.add_argument(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument("--center_crop",
                        action="store_true",
                        help="Whether to center crop images before resizing to resolution")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument("--lr_warmup_steps",
                        type=int,
                        default=500,
                        help="Number of steps for the warmup in the lr scheduler.")
    # NOTE: accepted for CLI compatibility; this script always uses GeminiAdamOptimizer and never reads it.
    parser.add_argument("--use_8bit_adam",
                        action="store_true",
                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    # NOTE: accepted for CLI compatibility; nothing in this script reads it.
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=("Precision used for the frozen VAE and text encoder. Choose between fp16 and bf16 (bfloat16);"
              " 'no' or leaving it unset keeps them in fp32. Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # The launcher's LOCAL_RANK environment variable takes precedence over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            logger.warning("You need not use --class_prompt without --with_prior_preservation.")

    if args.gradient_accumulation_steps < 1:
        raise ValueError("--gradient_accumulation_steps must be >= 1.")
    if args.gradient_accumulation_steps > 1:
        # The training loop calls optimizer.step() after every batch, so accumulation is not
        # actually performed; the value only changes the step/scheduler bookkeeping.
        logger.warning("--gradient_accumulation_steps > 1 is not implemented by this script: "
                       "the optimizer still steps after every batch.")

    return args


def _list_files(directory: Path) -> list:
    """Return the regular files directly inside `directory`, sorted by path (subdirectories are skipped)."""
    return sorted(path for path in directory.iterdir() if path.is_file())


def _load_rgb_image(path) -> Image.Image:
    """Open the image at `path`, convert it to RGB, and close the underlying file handle."""
    with Image.open(path) as image:
        # convert() loads the pixel data and returns a new image (a copy when the mode is
        # already RGB), so the result stays usable after the file is closed.
        return image.convert("RGB")


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Each item is a dict with ``instance_images`` (transformed tensor) and
    ``instance_prompt_ids`` (unpadded token ids). When a class data root is given it also
    has ``class_images`` and ``class_prompt_ids``. Indices wrap around the shorter of the
    two image lists, so the dataset length is the size of the larger one.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = _list_files(self.instance_data_root)
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            # An empty dataset would otherwise surface later as a ZeroDivisionError.
            raise ValueError(f"Instance images root {self.instance_data_root} does not contain any files.")
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = _list_files(self.class_data_root)
            self.num_class_images = len(self.class_images_path)
            if self.num_class_images == 0:
                raise ValueError(f"Class images root {self.class_data_root} does not contain any files.")
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = _load_rgb_image(self.instance_images_path[index % self.num_instance_images])
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = _load_rgb_image(self.class_images_path[index % self.num_class_images])
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids

        return example


class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    Every item is ``{"prompt": prompt, "index": index}`` for ``index`` in ``range(num_samples)``.
    """

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Return ``"<owner>/<model_id>"``.

    The owner is `organization` when given. Otherwise it is the user that `token`
    belongs to, looked up with ``whoami``. A None `token` falls back to the locally
    stored Hub token.
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
    """Wrap `model` in ColossalAI's GeminiDDP on the current device with the given placement policy.

    The misspelled parameter name ``placememt_policy`` is kept so keyword callers keep working.
    """
    from colossalai.nn.parallel import GeminiDDP

    model = GeminiDDP(model,
                      device=get_current_device(),
                      placement_policy=placememt_policy,
                      pin_memory=True,
                      search_range_mb=64)
    return model


def _generate_class_images(args) -> None:
    """Sample class images with the pretrained pipeline until ``class_data_dir`` holds ``num_class_images`` entries."""
    class_images_dir = Path(args.class_data_dir)
    # exist_ok: every rank runs this, so a separate exists()/mkdir() pair could race and raise.
    class_images_dir.mkdir(parents=True, exist_ok=True)
    cur_class_images = len(list(class_images_dir.iterdir()))
    if cur_class_images >= args.num_class_images:
        return

    # get_current_device() returns a device such as "cuda:0", so test the prefix rather
    # than comparing for equality with "cuda" (which never matches).
    torch_dtype = torch.float16 if str(get_current_device()).startswith("cuda") else torch.float32
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        safety_checker=None,
        revision=args.revision,
    )
    pipeline.set_progress_bar_config(disable=True)

    num_new_images = args.num_class_images - cur_class_images
    logger.info(f"Number of class images to sample: {num_new_images}.")

    # NOTE(review): the prompts are not sharded across ranks, so every rank samples all
    # `num_new_images` (with different random images, hence different file hashes).
    sample_dataset = PromptDataset(args.class_prompt, num_new_images)
    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

    pipeline.to(get_current_device())

    for example in tqdm(
            sample_dataloader,
            desc="Generating class images",
            disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
    ):
        images = pipeline(example["prompt"]).images

        for i, image in enumerate(images):
            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
            image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
            image.save(image_filename)

    del pipeline
    # Return the sampling pipeline's cached blocks before the training models are loaded.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _ensure_gitignore_patterns(directory, patterns) -> None:
    """Append every pattern in `patterns` that is not already a line of ``<directory>/.gitignore``.

    Existing content is preserved, and the file is created if it is missing.
    """
    gitignore_path = Path(directory) / ".gitignore"
    content = gitignore_path.read_text(encoding="utf-8") if gitignore_path.exists() else ""
    present = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in present]
    if not missing:
        return
    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        if content and not content.endswith("\n"):
            gitignore.write("\n")
        gitignore.write("".join(f"{pattern}\n" for pattern in missing))


def _init_output_dir(args):
    """Prepare ``args.output_dir`` on data-parallel rank 0.

    With ``--push_to_hub``, clone the Hub repository into it and make sure its
    ``.gitignore`` excludes intermediate ``step_*``/``epoch_*`` folders. Otherwise
    just create the directory.

    Returns:
        The ``Repository`` on rank 0 when pushing to the Hub, otherwise None.
    """
    if gpc.get_local_rank(ParallelMode.DATA) != 0:
        return None

    if args.push_to_hub:
        if args.hub_model_id is None:
            repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
        else:
            repo_name = args.hub_model_id
        repo = Repository(args.output_dir, clone_from=repo_name)
        _ensure_gitignore_patterns(args.output_dir, ("step_*", "epoch_*"))
        return repo

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    return None


def _save_pipeline(args, unet, save_path: str) -> None:
    """Rebuild the pretrained pipeline with the trained UNet and save it to `save_path` on data-parallel rank 0."""
    torch.cuda.synchronize()
    # Called on every rank, not just rank 0; presumably get_static_torch_model gathers
    # Gemini-managed parameters collectively — TODO confirm against ColossalAI.
    torch_unet = get_static_torch_model(unet)
    if gpc.get_local_rank(ParallelMode.DATA) == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=torch_unet,
            revision=args.revision,
        )
        pipeline.save_pretrained(save_path)
        logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])


def main(args):
    """Run DreamBooth fine-tuning of the UNet with ColossalAI Gemini, driven by the parsed `args`."""
    colossalai.launch_from_torch(config={})

    if args.seed is not None:
        gpc.set_seed(args.seed)

    if args.with_prior_preservation:
        _generate_class_images(args)

    # Handle the repository creation
    repo = _init_output_dir(args)

    # Load the tokenizer
    if args.tokenizer_name:
        logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
            use_fast=False,
        )
    elif args.pretrained_model_name_or_path:
        logger.info("Loading tokenizer from pretrained model", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path,
                                                                  revision=args.revision)

    # Load models and create wrapper for stable diffusion
    logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )

    logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
    with ColoInitContext(device=get_current_device()):
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                    subfolder="unet",
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)

    # Only the UNet is trained; the VAE and text encoder are frozen.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size *
                              gpc.get_world_size(ParallelMode.DATA))

    unet = gemini_zero_dpp(unet, args.placement)

    # config optimizer for colossalai zero
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # prepare dataset
    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Stack images and pad prompt ids to ``tokenizer.model_max_length`` into one batch dict."""
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {
                "input_ids": input_ids
            },
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.train_batch_size,
                                                   shuffle=True,
                                                   collate_fn=collate_fn,
                                                   num_workers=1)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(get_current_device(), dtype=weight_dtype)
    text_encoder.to(get_current_device(), dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}", ranks=[0])
    logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
    logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar on data-parallel rank 0.
    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
    progress_bar.set_description("Steps")
    global_step = 0

    torch.cuda.synchronize()
    for _ in range(args.num_train_epochs):
        # The inner `break` only leaves the current epoch; without this check an exhausted
        # step budget would still run one extra batch per remaining epoch.
        if global_step >= args.max_train_steps:
            break
        unet.train()
        for batch in train_dataloader:
            torch.cuda.reset_peak_memory_stats()
            # Move batch to gpu
            for key, value in batch.items():
                batch[key] = value.to(get_current_device(), non_blocking=True)

            optimizer.zero_grad()

            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # collate_fn appends class examples after instance examples, so the first
                # half of the batch is instance data and the second half is class data.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.backward(loss)

            # One optimizer/scheduler step per batch (no gradient accumulation).
            optimizer.step()
            lr_scheduler.step()
            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])

            progress_bar.update(1)
            global_step += 1
            logs = {
                "loss": loss.detach().item(),
                "lr": optimizer.param_groups[0]["lr"],
            }
            progress_bar.set_postfix(**logs)

            # save_steps <= 0 disables intermediate checkpoints instead of dividing by zero.
            if args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_pipeline(args, unet, os.path.join(args.output_dir, f"checkpoint-{global_step}"))

            if global_step >= args.max_train_steps:
                break

    _save_pipeline(args, unet, args.output_dir)

    if gpc.get_local_rank(ParallelMode.DATA) == 0 and args.push_to_hub:
        repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)