mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
886 lines
34 KiB
886 lines
34 KiB
import argparse |
|
import datetime |
|
import glob |
|
import os |
|
import sys |
|
import time |
|
from functools import partial |
|
|
|
import lightning.pytorch as pl |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from ldm.models.diffusion.ddpm import LatentDiffusion |
|
from lightning.pytorch import seed_everything |
|
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint |
|
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger |
|
from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy |
|
from lightning.pytorch.trainer import Trainer |
|
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only |
|
from omegaconf import OmegaConf |
|
from packaging import version |
|
from PIL import Image |
|
from prefetch_generator import BackgroundGenerator |
|
from torch.utils.data import DataLoader, Dataset |
|
|
|
LIGHTNING_PACK_NAME = "lightning.pytorch." |
|
|
|
from ldm.data.base import Txt2ImgIterableBaseDataset |
|
from ldm.util import instantiate_from_config |
|
|
|
# from ldm.modules.attention import enable_flash_attentions |
|
|
|
|
|
class DataLoaderX(DataLoader):
    """``DataLoader`` variant that prefetches batches in the background.

    Iteration is delegated to ``DataLoader.__iter__``; the resulting iterator is
    wrapped in ``prefetch_generator.BackgroundGenerator`` so that the next batches
    are prepared while the current one is being consumed.
    """

    def __iter__(self):
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)
|
|
|
|
|
def get_parser(**parser_kwargs):
    """Create the script-level command-line parser.

    Lightning ``Trainer`` options are not added here; the caller appends them with
    ``Trainer.add_argparse_args``.

    Args:
        **parser_kwargs: forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser containing the options defined below.
    """

    def str2bool(v):
        """Convert a yes/no style string to ``bool`` (``bool`` inputs are returned unchanged).

        Raises:
            argparse.ArgumentTypeError: if the string is not a recognised boolean spelling.
        """
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)

    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="train",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project",
    )
    parser.add_argument(
        "-c",
        "--ckpt",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="load pretrained checkpoint from stable AI",
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything",
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name",
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        # Each run gets a sub-directory here holding its checkpoints/ and configs/.
        help="root directory under which per-run log, checkpoint and config directories are created",
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate",
    )

    return parser
|
|
|
|
|
def nondefault_trainer_args(opt):
    """Return the sorted names of ``Trainer`` CLI options whose value in *opt* differs from the default.

    The reference defaults come from parsing an empty command line with a parser that
    only holds the options contributed by ``Trainer.add_argparse_args``.

    Args:
        opt: parsed namespace that contains (at least) every Trainer option.

    Returns:
        list[str]: option names with non-default values, in sorted order.
    """
    reference_parser = Trainer.add_argparse_args(argparse.ArgumentParser())
    defaults = vars(reference_parser.parse_args([]))
    changed = [key for key, default_value in defaults.items() if getattr(opt, key) != default_value]
    return sorted(changed)
|
|
|
|
|
class WrappedDataset(Dataset):
    """Expose any object supporting ``len()`` and indexing as a torch ``Dataset``.

    The wrapped object is kept unchanged on ``self.data``; length and item access
    are forwarded to it.
    """

    def __init__(self, dataset):
        self.data = dataset

    def __getitem__(self, idx):
        # Forward item access straight to the wrapped object.
        return self.data[idx]

    def __len__(self):
        return len(self.data)
|
|
|
|
|
def worker_init_fn(_):
    """Initialise a ``DataLoader`` worker process (pass as ``worker_init_fn=``).

    For ``Txt2ImgIterableBaseDataset`` instances, each worker is given its own
    contiguous slice of ``dataset.valid_ids``, stored on ``dataset.sample_ids``.
    In every case NumPy's global RNG is re-seeded with a per-worker offset.

    Args:
        _: worker id passed by ``DataLoader``; unused, because the id is read from
            ``torch.utils.data.get_worker_info()`` instead.

    Returns:
        None. Does nothing when called outside a DataLoader worker process.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        # Not inside a worker (e.g. a direct call): there is no per-worker state to set up.
        return None

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Divide the records into equal contiguous parts, one per worker.
        # NOTE(review): the num_records % num_workers remainder is not assigned to any worker.
        split_size = dataset.num_records // worker_info.num_workers
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
        # Use a randomly chosen word of the current RNG state key as the base seed.
        # The state is re-read after choice() on purpose, matching the original draw order.
        current_id = int(np.random.choice(len(np.random.get_state()[1]), 1)[0])
        base_seed = np.random.get_state()[1][current_id]
    else:
        base_seed = np.random.get_state()[1][0]

    # np.random.seed accepts only [0, 2**32 - 1]; wrap explicitly instead of relying on
    # silent uint32 overflow of base_seed + worker_id.
    return np.random.seed((int(base_seed) + worker_id) % 2**32)
|
|
|
|
|
class DataModuleFromConfig(pl.LightningDataModule):
    """``LightningDataModule`` whose datasets are instantiated from config dicts.

    Each split config (``train``/``validation``/``test``/``predict``) is passed to
    ``instantiate_from_config`` in :meth:`setup`. A ``*_dataloader`` method is attached
    only for splits whose config was given.

    Args:
        batch_size: batch size used by every dataloader.
        train, validation, test, predict: optional dataset configs.
        wrap: wrap every instantiated dataset in ``WrappedDataset``.
        num_workers: number of dataloader workers; defaults to ``2 * batch_size``.
        shuffle_test_loader: shuffle the test loader (ignored for iterable datasets).
        use_worker_init_fn: install ``worker_init_fn`` even for non-iterable datasets.
        shuffle_val_dataloader: shuffle the validation loader (ignored for iterable datasets).
    """

    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Register a dataloader method only for the splits that were configured.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        # Instantiate each dataset once; the instances are discarded here and rebuilt in setup().
        # (Presumably for one-time side effects such as downloads/caching — not visible from here.)
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Instantiate every configured dataset into ``self.datasets`` (optionally wrapped)."""
        self.datasets = {split: instantiate_from_config(cfg) for split, cfg in self.dataset_configs.items()}
        if self.wrap:
            self.datasets = {split: WrappedDataset(ds) for split, ds in self.datasets.items()}

    def _make_dataloader(self, split, shuffle):
        """Build a ``DataLoaderX`` for *split*; shuffling is always disabled for iterable datasets."""
        dataset = self.datasets[split]
        is_iterable_dataset = isinstance(dataset, Txt2ImgIterableBaseDataset)
        # Iterable datasets need worker_init_fn to split their records across workers.
        init_fn = worker_init_fn if (is_iterable_dataset or self.use_worker_init_fn) else None
        return DataLoaderX(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # DataLoader rejects shuffle=True for iterable-style datasets.
            shuffle=shuffle and not is_iterable_dataset,
            worker_init_fn=init_fn,
        )

    def _train_dataloader(self):
        return self._make_dataloader("train", shuffle=True)

    def _val_dataloader(self, shuffle=False):
        return self._make_dataloader("validation", shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        return self._make_dataloader("test", shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        return self._make_dataloader("predict", shuffle=shuffle)
|
|
|
|
|
class SetupCallback(Callback):
    """Prepare the run directory layout and persist configs when fitting starts.

    Also saves ``<ckptdir>/last.ckpt`` when training is interrupted with Ctrl+C.

    Args:
        resume: truthy when resuming an existing run.
        now: timestamp string used in the saved config file names.
        logdir, ckptdir, cfgdir: run, checkpoint and config directories.
        config: project config (OmegaConf) saved as ``<now>-project.yaml``.
        lightning_config: ``lightning`` config section saved as ``<now>-lightning.yaml``.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def _save_last_checkpoint(self, trainer):
        """Save ``<ckptdir>/last.ckpt``; message printed on global rank 0 only."""
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
        # Per the Lightning docs, Trainer.save_checkpoint must be called on all processes
        # when the strategy handles distributed checkpointing; the strategy decides who writes.
        trainer.save_checkpoint(os.path.join(self.ckptdir, "last.ckpt"))

    def on_keyboard_interrupt(self, trainer, pl_module):
        # Legacy hook (removed from Lightning in 1.7); kept for older versions.
        self._save_last_checkpoint(trainer)

    def on_exception(self, trainer, pl_module, exception):
        # Current Lightning reports Ctrl+C through on_exception with a KeyboardInterrupt.
        if isinstance(exception, KeyboardInterrupt):
            self._save_last_checkpoint(trainer)

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs.
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            # Directory for step-based checkpoints, if that callback is configured.
            if "callbacks" in self.lightning_config:
                if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
                    os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True)

            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
            )
        else:
            # Non-zero ranks: move an already-existing log directory of a fresh run to child_runs/.
            # NOTE(review): on a shared filesystem this can race with rank 0 creating the same
            # directory above — confirm this is still desired.
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    # Another process moved or removed it first.
                    pass
|
|
|
|
|
class ImageLogger(Callback):
    """Periodically render ``pl_module.log_images`` output.

    Image grids are written as PNG files under ``<logger.save_dir>/images/<split>/``.
    For logger types listed in ``self.logger_log_images``, they are also pushed to the
    logger. Logging is currently triggered only from ``on_validation_batch_end``.
    """

    def __init__(
        self,
        batch_frequency,  # log every `batch_frequency` steps (or batch indices)
        max_images,  # maximum number of images logged per key
        clamp=True,  # clamp tensor values to [-1, 1] before saving
        increase_log_steps=True,  # additionally log at steps 1, 2, 4, ... (powers of two <= batch_frequency)
        rescale=True,  # map values from [-1, 1] to [0, 1] before saving
        disabled=False,  # disable logging entirely
        log_on_batch_idx=False,  # schedule on batch_idx instead of global_step
        log_first_step=False,  # allow logging at step/index 0
        log_images_kwargs=None,  # extra keyword arguments forwarded to pl_module.log_images
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Logger classes whose `experiment` can receive image grids, mapped to the pushing method.
        # TensorBoardLogger's experiment provides add_image; CSVLogger's does not, so it is not listed.
        self.logger_log_images = {
            pl.loggers.TensorBoardLogger: self._testtube,
        }
        # Exponentially spaced early log steps: 1, 2, 4, ..., up to batch_frequency.
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        """Push each image grid to the logger's experiment via ``add_image`` (rank 0 only).

        Grids are tagged ``<split>/<key>`` at ``pl_module.global_step``; ``batch_idx`` is unused.
        """
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            pl_module.logger.experiment.add_image(f"{split}/{k}", grid, global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(
        self,
        save_dir,
        split,  # the split (train/val) on which to log the images
        images,  # dict of image batches to log
        global_step,
        current_epoch,
        batch_idx,
    ):
        """Save one PNG grid per key under ``<save_dir>/images/<split>/`` (rank 0 only)."""
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)  # c,h,w -> h,w,c
            # Clip before the uint8 cast: out-of-range floats would otherwise wrap around.
            grid = np.clip(grid.numpy(), 0.0, 1.0)
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Generate images with ``pl_module.log_images`` and log them locally and to the logger."""
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        # check_frequency is evaluated first because it advances the log_steps schedule.
        if not (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            return

        logger = type(pl_module.logger)
        is_train = pl_module.training
        if is_train:
            pl_module.eval()
        try:
            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            # Keep at most max_images per key; move tensors to CPU and optionally clamp.
            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx
            )

            # Push to the logger if its type is supported; otherwise this is a no-op.
            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)
        finally:
            # Restore training mode even if image generation or saving failed.
            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        """Return True when *check_idx* is due for logging; consumes one early log step when it is."""
        on_schedule = (check_idx % self.batch_freq) == 0 or check_idx in self.log_steps
        if on_schedule and (check_idx > 0 or self.log_first_step):
            if self.log_steps:
                self.log_steps.pop(0)
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Train-time image logging is intentionally disabled in this script.
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx is accepted because Lightning passes it when there are several val dataloaders.
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if getattr(pl_module, "calibrate_grad_norm", False) and batch_idx > 0 and batch_idx % 25 == 0:
            # This class defines no log_gradients; only call it if a subclass provides one.
            log_gradients = getattr(self, "log_gradients", None)
            if log_gradients is not None:
                log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
|
|
|
|
class CUDACallback(Callback):
    """Report per-epoch wall time and peak CUDA memory, averaged across ranks.

    Does nothing on epochs when the strategy's root device is not a CUDA device.
    """

    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py

    def on_train_start(self, trainer, pl_module):
        rank_zero_info("Training is starting")

    # Called once, when the whole training run ends (not per epoch).
    def on_train_end(self, trainer, pl_module):
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        device = trainer.strategy.root_device
        if device.type != "cuda":
            return
        # Reset the peak-memory counter and start the epoch timer.
        torch.cuda.reset_peak_memory_stats(device.index)
        torch.cuda.synchronize(device.index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        device = trainer.strategy.root_device
        if device.type != "cuda":
            return
        torch.cuda.synchronize(device.index)
        max_memory = torch.cuda.max_memory_allocated(device.index) / 2**20  # MiB
        epoch_time = time.time() - self.start_time

        try:
            # Strategy.reduce leaves non-tensor inputs untouched, so pack both values into a tensor
            # and request the mean explicitly (some strategies default to a sum).
            stats = torch.tensor([max_memory, epoch_time], dtype=torch.float64, device=device)
            stats = trainer.strategy.reduce(stats, reduce_op="mean")
            max_memory, epoch_time = stats.tolist()

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass
|
|
|
|
|
if __name__ == "__main__":
    # Command-line handling:
    #   * script options (config files, train/debug mode, name/postfix, resume, ...) come from get_parser();
    #   * `--key value` arguments are interpreted as arguments to the Lightning Trainer;
    #   * remaining `nested.key=value` arguments are merged into the config as OmegaConf dotlist overrides.
    # Configs given with -b/--base are merged from left to right, followed by the command-line overrides.
    #
    # Config layout read below:
    # model:
    #   base_learning_rate: float
    #   params:                 # keyword arguments for LatentDiffusion
    #     key: value
    # data:                     # keyword arguments for DataModuleFromConfig
    #   batch_size: int
    #   wrap: bool
    #   train:
    #     target: path to train dataset
    #     params:
    #       key: value
    #   validation:
    #     target: path to validation dataset
    #     params:
    #       key: value
    #   test:
    #     target: path to test dataset
    #     params:
    #       key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #     additional arguments to trainer (a `strategy` entry is passed to ColossalAIStrategy)
    #   logger:
    #     keyword arguments for WandbLogger (TensorBoardLogger is used when absent)
    #   modelcheckpoint:
    #     params: ModelCheckpoint overrides
    #   callbacks:
    #     only the presence of `metrics_over_trainsteps_checkpoint` is inspected (by SetupCallback)

    # Timestamp used to name a fresh log directory.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    # -n/--name and -r/--resume are mutually exclusive.
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )

    # Resolve the log directory: reuse the one being resumed, or build a new timestamped name.
    ckpt = None
    if opt.resume:
        rank_zero_info("Resuming from {}".format(opt.resume))
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # A checkpoint file: its log directory is taken to be two levels up (<logdir>/checkpoints/<file>).
            paths = opt.resume.split("/")
            logdir = "/".join(paths[:-2])
            rank_zero_info("logdir: {}".format(logdir))
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        # Prepend the configs saved with the resumed run to the user-supplied base configs.
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        # The run name is the last path component of the log directory.
        nowname = logdir.split("/")[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            rank_zero_info("Using base config {}".format(opt.base))
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    # An explicit -c/--ckpt overrides any checkpoint implied by --resume.
    if opt.ckpt:
        ckpt = opt.ckpt

    # Checkpoint and config directories inside the log directory.
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    # Seed all random number generators for reproducibility.
    seed_everything(opt.seed)

    # The Trainer is created inside the try-block below; pre-bind the name so the
    # except/finally handlers do not raise NameError (hiding the real error) when
    # setup fails before the Trainer exists.
    trainer = None

    def _is_global_zero():
        """True on global rank 0; before the Trainer exists, use the rank recorded on rank_zero_only."""
        if trainer is not None:
            return trainer.global_rank == 0
        return getattr(rank_zero_only, "rank", 0) == 0

    try:
        # Load and merge configs: base files left-to-right, then CLI dotlist overrides.
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # Merge Trainer options given on the command line into the trainer config.
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)

        # Anything other than accelerator == "gpu" (including an unset accelerator) is treated as a CPU run.
        cpu = trainer_config.get("accelerator") != "gpu"
        if cpu and "accelerator" in trainer_config:
            del trainer_config["accelerator"]
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        use_fp16 = trainer_config.get("precision", 32) == 16
        config.model["params"].update({"use_fp16": use_fp16})

        if ckpt is not None:
            # Pass the checkpoint path to the model constructor via its "ckpt" parameter.
            config.model["params"].update({"ckpt": ckpt})
            rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))

        model = LatentDiffusion(**config.model.get("params", dict()))
        # trainer and callbacks
        trainer_kwargs = dict()

        # Logger: WandbLogger when `lightning.logger` is configured, TensorBoardLogger otherwise.
        default_logger_cfgs = {
            "wandb": {
                "name": nowname,
                "save_dir": logdir,
                "offline": opt.debug,
                "id": nowname,
            },
            "tensorboard": {"save_dir": logdir, "name": "diff_tb", "log_graph": True},
        }
        default_logger_cfg = default_logger_cfgs["tensorboard"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
            trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
        else:
            logger_cfg = default_logger_cfg
            trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)

        # Strategy: ColossalAI when `trainer.strategy` is configured, DDP otherwise.
        if "strategy" in trainer_config:
            strategy_cfg = trainer_config["strategy"]
            trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
        else:
            strategy_cfg = {"find_unused_parameters": False}
            trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)

        # ModelCheckpoint settings; `monitor` comes from the model when it defines one.
        default_modelckpt_cfg = {
            "dirpath": ckptdir,
            "filename": "{epoch:06}",
            "verbose": True,
            "save_last": True,
        }
        if hasattr(model, "monitor"):
            default_modelckpt_cfg["monitor"] = model.monitor
            default_modelckpt_cfg["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint["params"]
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        # NOTE(review): on Lightning >= 1.4 modelckpt_cfg is never turned into a callback, so the
        # epoch-level checkpoint configured above is inactive — confirm this is intended.
        if version.parse(pl.__version__) < version.parse("1.4.0"):
            trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg)

        trainer_kwargs.setdefault("callbacks", [])
        setup_callback_config = {
            "resume": opt.resume,  # resume training if applicable
            "now": now,
            "logdir": logdir,  # directory to save the log file
            "ckptdir": ckptdir,  # directory to save the checkpoint file
            "cfgdir": cfgdir,  # directory to save the configuration file
            "config": config,  # project configuration
            "lightning_config": lightning_config,  # the `lightning` section of the configuration
        }
        trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config))

        image_logger_config = {
            "batch_frequency": 750,  # how frequently to log images
            "max_images": 4,  # maximum number of images to log
            "clamp": True,  # whether to clamp pixel values to [-1,1]
        }
        trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))

        learning_rate_logger_config = {
            "logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
            # "log_momentum": True  # whether to log momentum (currently commented out)
        }
        trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))

        metrics_over_trainsteps_checkpoint_config = {
            "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
            "filename": "{epoch:06}-{step:09}",
            "verbose": True,
            "save_top_k": -1,
            "every_n_train_steps": 10000,
            "save_weights_only": True,
        }
        trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
        trainer_kwargs["callbacks"].append(CUDACallback())

        # Keyword arguments take precedence over the argparse-derived trainer options.
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir

        # Create a data module based on the configuration file
        data = DataModuleFromConfig(**config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()

        for k in data.datasets:
            rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # Learning rate: optionally scaled by accumulate_grad_batches * num_gpus * batch_size.
        bs, base_lr = config.data.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = trainer_config["devices"]
        else:
            ngpu = 1
        if "accumulate_grad_batches" in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            rank_zero_info(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
                )
            )
        else:
            model.learning_rate = base_lr
            rank_zero_info("++++ NOT USING LR SCALING ++++")
            rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")

        # Allow checkpointing via USR1
        def melk(*args, **kwargs):
            # NOTE(review): save_checkpoint is only called on rank 0 here; with multi-process
            # strategies Lightning expects it on all ranks — confirm this cannot hang.
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        # Drop into pudb on rank 0 (triggered via SIGUSR2).
        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # Run training (validation runs as part of fit); save last.ckpt if fit raises.
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        # Print the maximum GPU memory allocated during training
        print(f"GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
        # if not opt.no_test and not trainer.interrupted:
        #     trainer.test(model, data)
    except Exception:
        # Optionally open a post-mortem debugger on rank 0, then re-raise.
        if opt.debug and _is_global_zero():
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # In debug mode, move a freshly created log directory to debug_runs/ (rank 0 only).
        if opt.debug and not opt.resume and _is_global_zero() and os.path.exists(logdir):
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        if trainer is not None and trainer.global_rank == 0:
            print(trainer.profiler.summary())
|
|
|