@@ -10,11 +10,8 @@ import time
import numpy as np
import torch
import torchvision
import lightning.pytorch as pl
try:
import lightning.pytorch as pl
except:
import pytorch_lightning as pl
from functools import partial
@@ -23,19 +20,15 @@ from packaging import version
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from ldm.models.diffusion.ddpm import LatentDiffusion
try:
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "lightning.pytorch."
except:
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
LIGHTNING_PACK_NAME = "pytorch_lightning."
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy
LIGHTNING_PACK_NAME = "lightning.pytorch."
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
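# For reference (not part of this hunk): assuming the standard ldm.util implementation,
# instantiate_from_config roughly does the following -- it imports the class named by the
# "target" string and calls it with the "params" dict, which is exactly the indirection
# the direct constructor calls below replace:
#
#     def instantiate_from_config(config):
#         cls = get_obj_from_str(config["target"])  # dotted import path, e.g. "main.ImageLogger"
#         return cls(**config.get("params", dict()))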
@@ -687,153 +680,114 @@ if __name__ == "__main__":
config.model["params"].update({"ckpt": ckpt})
rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
model = instantiate_from_config(config.model)
model = LatentDiffusion(**config.model.get("params", dict()))
# trainer and callbacks
trainer_kwargs = dict()
# configure the logger
# Default logger configs to log training metrics during the training process.
# These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger.
default_logger_cfgs = {
"wandb": {
"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
}
},
"tensorboard": {
"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "diff_tb",
"log_graph": True
}
}
}
# Set up the logger for TensorBoard
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
trainer_kwargs["logger"] = WandbLogger(**logger_cfg)
else:
logger_cfg = default_logger_cfg
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg)
# configure the strategy; the default is DDP
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg)
else:
strategy_cfg = {
"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy",
"params": {
"find_unused_parameters": False
}
}
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
strategy_cfg = {"find_unused_parameters": False}
trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg)
# Set up ModelCheckpoint callback to save best models
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
}
}
if hasattr(model, "monitor"):
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
default_modelckpt_cfg["monitor"] = model.monitor
default_modelckpt_cfg["save_top_k"] = 3
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
modelckpt_cfg = lightning_config.modelcheckpoint["params"]
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# Set up various callbacks, including logging, learning rate monitoring, and CUDA management
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {  # callback to set up the training
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,  # resume training if applicable
"now": now,
"logdir": logdir,  # directory to save the log file
"ckptdir": ckptdir,  # directory to save the checkpoint file
"cfgdir": cfgdir,  # directory to save the configuration file
"config": config,  # configuration dictionary
"lightning_config": lightning_config,  # LightningModule configuration
}
},
"image_logger": {  # callback to log image data
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750,  # how frequently to log images
"max_images": 4,  # maximum number of images to log
"clamp": True  # whether to clamp pixel values to [0,1]
}
},
"learning_rate_logger": {  # callback to log learning rate
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
# "log_momentum": True  # whether to log momentum (currently commented out)
}
},
"cuda_callback": {  # callback to handle CUDA-related operations
"target": "main.CUDACallback"
},
}
# If the LightningModule configuration has specified callbacks, use those
# Otherwise, create an empty OmegaConf configuration object
if " callbacks " in lightning_config :
callbacks_cfg = lightning_config . callbacks
else :
callbacks_cfg = OmegaConf . create ( )
trainer_kwargs [ " checkpoint_callback " ] = ModelCheckpoint ( * * modelckpt_cfg )
#Create an empty OmegaConf configuration object
callbacks_cfg = OmegaConf . create ( )
#Instantiate items according to the configs
trainer_kwargs . setdefault ( " callbacks " , [ ] )
setup_callback_config = {
" resume " : opt . resume , # resume training if applicable
" now " : now ,
" logdir " : logdir , # directory to save the log file
" ckptdir " : ckptdir , # directory to save the checkpoint file
" cfgdir " : cfgdir , # directory to save the configuration file
" config " : config , # configuration dictionary
" lightning_config " : lightning_config , # LightningModule configuration
}
trainer_kwargs [ " callbacks " ] . append ( SetupCallback ( * * setup_callback_config ) )
# If the 'metrics_over_trainsteps_checkpoint' callback is specified in the
# LightningModule configuration, update the default callbacks configuration
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
}
image_logger_config = {
"batch_frequency": 750,  # how frequently to log images
"max_images": 4,  # maximum number of images to log
"clamp": True  # whether to clamp pixel values to [0,1]
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config))
# Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
learning_rate_logger_config = {
"logging_interval": "step",  # logging frequency (either 'step' or 'epoch')
# "log_momentum": True  # whether to log momentum (currently commented out)
}
trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config))
metrics_over_trainsteps_checkpoint_config = {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config))
trainer_kwargs["callbacks"].append(CUDACallback())
# Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir
# Create a data module based on the configuration file
data = instantiate_from_config(config.data)
data = DataModuleFromConfig(**config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
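# (Context, not shown in this hunk: the explicit calls the NOTE above refers to are
# presumably data.prepare_data() and data.setup(), invoked directly on the datamodule
# just after this point in main.py.)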
@@ -846,7 +800,7 @@ if __name__ == "__main__":
# Configure learning rate based on the batch size, base learning rate and number of GPUs
# If scale_lr is true, calculate the learning rate based on additional factors
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
bs, base_lr = config.data.batch_size, config.model.base_learning_rate
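# (Sketch, assuming the stock latent-diffusion scaling rule applied further down in
# main.py and not shown in this hunk: with scale_lr enabled the effective rate becomes
#     model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
# otherwise model.learning_rate = base_lr.)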
if not cpu:
ngpu = trainer_config["devices"]
else: