[example] fix dreambooth format (#2315)

pull/2317/head
Fazzie-Maqianli 2023-01-04 16:20:00 +08:00 committed by GitHub
parent da1c47f060
commit a9b27b9265
5 changed files with 90 additions and 101 deletions


@@ -108,7 +108,7 @@ lightning:
params:
use_chunk: True
enable_distributed_storage: True
placement_policy: auto
placement_policy: cuda
force_outputs_fp32: true
log_every_n_steps: 2


@@ -105,7 +105,7 @@ lightning:
params:
use_chunk: True
enable_distributed_storage: True
placement_policy: auto
placement_policy: cuda
force_outputs_fp32: true
log_every_n_steps: 2


@@ -109,7 +109,7 @@ lightning:
params:
use_chunk: True
enable_distributed_storage: True
placement_policy: auto
placement_policy: cuda
force_outputs_fp32: true
log_every_n_steps: 2


@@ -102,7 +102,7 @@ lightning:
params:
use_chunk: True
enable_distributed_storage: True
placement_policy: auto
placement_policy: cuda
force_outputs_fp32: true
log_every_n_steps: 2
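All four Lightning configs above switch Gemini's placement_policy from "auto" to "cuda", keeping model data resident on the GPU instead of letting Gemini migrate chunks heuristically. A minimal sketch (the configs/*.yaml layout is an assumption, since the diff above does not show file names) for checking that every config now carries the "cuda" policy, without guessing the exact key nesting under "lightning:":

import glob
import yaml  # PyYAML

def find_placement_policies(node, found=None):
    # Walk the parsed YAML tree and collect every placement_policy value.
    found = [] if found is None else found
    if isinstance(node, dict):
        for key, value in node.items():
            if key == "placement_policy":
                found.append(value)
            find_placement_policies(value, found)
    elif isinstance(node, list):
        for item in node:
            find_placement_policies(item, found)
    return found

for path in glob.glob("configs/*.yaml"):  # hypothetical directory
    with open(path) as f:
        policies = find_placement_policies(yaml.safe_load(f))
    assert all(p == "cuda" for p in policies), (path, policies)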


@@ -1,38 +1,32 @@
import argparse
import hashlib
import itertools
import math
import os
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.checkpoint
from copy import deepcopy
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from packaging import version
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.parallel.utils import convert_to_torch_module
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
disable_existing_loggers()
logger = get_dist_logger()
@@ -118,8 +112,10 @@ def parse_args(input_args=None):
"--num_class_images",
type=int,
default=100,
help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."),
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
@@ -132,23 +128,26 @@ def parse_args(input_args=None):
"--resolution",
type=int,
default=512,
help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"),
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument("--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution")
parser.add_argument("--train_batch_size",
type=int,
default=4,
help="Batch size (per device) for the training dataloader.")
parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
@@ -184,16 +183,17 @@ def parse_args(input_args=None):
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--lr_warmup_steps",
type=int,
default=500,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,8 +208,10 @@ def parse_args(input_args=None):
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
@@ -219,7 +221,8 @@ def parse_args(input_args=None):
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -285,12 +288,14 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
self.image_transforms = transforms.Compose([
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
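The Compose block is only re-indented; the pipeline still resizes, crops, converts to a tensor, and normalizes to [-1, 1]. A self-contained sketch (synthetic image instead of an instance image) of what it produces:

import torch
from PIL import Image
from torchvision import transforms

size, center_crop = 512, False
image_transforms = transforms.Compose(
    [
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

img = Image.new("RGB", (768, 768), color=(128, 128, 128))  # stand-in for a training image
tensor = image_transforms(img)
assert tensor.shape == (3, size, size) and tensor.dtype == torch.float32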
@@ -352,26 +357,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placememt_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placememt_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(
model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
)
return model
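The rewritten gemini_zero_dpp drops the version branching and always wraps the module with GeminiDDP. A minimal sketch of the same wrapping on a toy module, assuming a torchrun-launched process so launch_from_torch can read the distributed environment variables (the Linear layer stands in for the UNet):

import torch
import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

colossalai.launch_from_torch(config={})  # expects torchrun-provided env vars

with ColoInitContext():
    model = torch.nn.Linear(1024, 1024)  # stand-in for the UNet

model = GeminiDDP(
    model, device=get_current_device(), placement_policy="cuda", pin_memory=True, search_range_mb=32
)
out = model(torch.randn(4, 1024, device=get_current_device()))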
@@ -383,7 +373,7 @@ def main(args):
"gradient_accumulation_steps": args.gradient_accumulation_steps,
"clip_grad_norm": args.max_grad_norm,
}
colossalai.launch_from_torch(config=config)
pg = ProcessGroup()
@@ -414,9 +404,11 @@ def main(args):
pipeline.to(get_current_device())
for example in tqdm(sample_dataloader,
desc="Generating class images",
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0):
for example in tqdm(
sample_dataloader,
desc="Generating class images",
disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
@@ -466,23 +458,24 @@ def main(args):
logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,)
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
with ColoInitContext():
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -491,7 +484,7 @@ def main(args):
unet.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2)
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
unet = gemini_zero_dpp(unet, pg, args.placement)
@@ -502,7 +495,7 @@ def main(args):
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# prepare dataset
logger.info(f"Prepare dataset", ranks=[0])
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
@@ -527,9 +520,7 @@ def main(args):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad(
{
"input_ids": input_ids
},
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
@@ -541,11 +532,9 @@ def main(args):
}
return batch
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=1)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
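collate_fn and the DataLoader construction are likewise only condensed onto fewer lines. A self-contained sketch with stand-in data (the real collate_fn also pads input_ids via tokenizer.pad as shown above; num_workers is 0 here purely to keep the snippet single-process, the script uses 1):

import torch

train_dataset = [{"pixel_values": torch.zeros(3, 512, 512)} for _ in range(8)]  # dummy examples

def collate_fn(examples):
    # Simplified stand-in: stack image tensors only, no tokenizer padding.
    pixel_values = torch.stack([e["pixel_values"] for e in examples])
    return {"pixel_values": pixel_values.to(memory_format=torch.contiguous_format).float()}

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=0
)
batch = next(iter(train_dataloader))
assert batch["pixel_values"].shape == (4, 3, 512, 512)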
@@ -662,8 +651,8 @@ def main(args):
global_step += 1
logs = {
"loss": loss.detach().item(),
"lr": optimizer.param_groups[0]['lr']
} #lr_scheduler.get_last_lr()[0]}
"lr": optimizer.param_groups[0]["lr"],
} # lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step % args.save_steps == 0:
@@ -681,15 +670,15 @@ def main(args):
break
torch.cuda.synchronize()
unet=convert_to_torch_module(unet)
unet = convert_to_torch_module(unet)
if gpc.get_local_rank(ParallelMode.DATA) == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])