fix/transformer-verison (#2581)

pull/2637/head
Fazzie-Maqianli 2023-02-08 13:50:27 +08:00 committed by GitHub
parent d3480396f8
commit 292c81ed7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 24 deletions

View File

@ -52,7 +52,7 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install transformers==4.19.2 diffusers invisible-watermark
pip install transformers diffusers invisible-watermark
```
#### Step 2: install lightning

View File

@ -18,7 +18,7 @@ dependencies:
- test-tube>=0.7.5
- streamlit==1.12.1
- einops==0.3.0
- transformers==4.19.2
- transformers
- webdataset==0.2.5
- kornia==0.6
- open_clip_torch==2.0.2

View File

@ -9,7 +9,7 @@ omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
transformers==4.19.2
transformers
webdataset==0.2.5
open-clip-torch==2.7.0
gradio==3.11

View File

@ -10,7 +10,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
@ -133,9 +133,13 @@ def parse_args(input_args=None):
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument("--center_crop",
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help="Whether to center crop images before resizing to resolution")
help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."),
)
parser.add_argument("--train_batch_size",
type=int,
default=4,
@ -149,13 +153,6 @@ def parse_args(input_args=None):
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help=
"Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
@ -356,7 +353,6 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
def main(args):
if args.seed is None:
colossalai.launch_from_torch(config={})
else:
@ -410,7 +406,8 @@ def main(args):
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
@ -469,9 +466,8 @@ def main(args):
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
args.learning_rate = args.learning_rate * args.train_batch_size * world_size
unet = gemini_zero_dpp(unet, args.placement)
@ -529,7 +525,7 @@ def main(args):
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
@ -537,8 +533,8 @@ def main(args):
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
@ -553,14 +549,14 @@ def main(args):
text_encoder.to(get_current_device(), dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
total_batch_size = args.train_batch_size * world_size
logger.info("***** Running training *****", ranks=[0])
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
@ -568,7 +564,6 @@ def main(args):
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
# Only show the progress bar once on each machine.