mirror of https://github.com/hpcaitech/ColossalAI
fix/transformer-verison (#2581)
parent
d3480396f8
commit
292c81ed7c
|
@ -52,7 +52,7 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
|
|||
|
||||
```
|
||||
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
|
||||
pip install transformers==4.19.2 diffusers invisible-watermark
|
||||
pip install transformers diffusers invisible-watermark
|
||||
```
|
||||
|
||||
#### Step 2: install lightning
|
||||
|
|
|
@ -18,7 +18,7 @@ dependencies:
|
|||
- test-tube>=0.7.5
|
||||
- streamlit==1.12.1
|
||||
- einops==0.3.0
|
||||
- transformers==4.19.2
|
||||
- transformers
|
||||
- webdataset==0.2.5
|
||||
- kornia==0.6
|
||||
- open_clip_torch==2.0.2
|
||||
|
|
|
@ -9,7 +9,7 @@ omegaconf==2.1.1
|
|||
test-tube>=0.7.5
|
||||
streamlit>=0.73.1
|
||||
einops==0.3.0
|
||||
transformers==4.19.2
|
||||
transformers
|
||||
webdataset==0.2.5
|
||||
open-clip-torch==2.7.0
|
||||
gradio==3.11
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
@ -133,9 +133,13 @@ def parse_args(input_args=None):
|
|||
default="cpu",
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument("--center_crop",
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to center crop images before resizing to resolution")
|
||||
help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."),
|
||||
)
|
||||
parser.add_argument("--train_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
|
@ -149,13 +153,6 @@ def parse_args(input_args=None):
|
|||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help=
|
||||
"Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
|
@ -356,7 +353,6 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
|
|||
|
||||
|
||||
def main(args):
|
||||
|
||||
if args.seed is None:
|
||||
colossalai.launch_from_torch(config={})
|
||||
else:
|
||||
|
@ -410,7 +406,8 @@ def main(args):
|
|||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
|
@ -469,9 +466,8 @@ def main(args):
|
|||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
|
||||
if args.scale_lr:
|
||||
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
|
||||
args.learning_rate = args.learning_rate * args.train_batch_size * world_size
|
||||
|
||||
unet = gemini_zero_dpp(unet, args.placement)
|
||||
|
||||
|
@ -529,7 +525,7 @@ def main(args):
|
|||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
@ -537,8 +533,8 @@ def main(args):
|
|||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
|
@ -553,14 +549,14 @@ def main(args):
|
|||
text_encoder.to(get_current_device(), dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
|
||||
total_batch_size = args.train_batch_size * world_size
|
||||
|
||||
logger.info("***** Running training *****", ranks=[0])
|
||||
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
|
||||
|
@ -568,7 +564,6 @@ def main(args):
|
|||
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
|
|
Loading…
Reference in New Issue