From 48d33b1b1753f19361e7e54a68a7ac5999dc02e4 Mon Sep 17 00:00:00 2001 From: HELSON Date: Fri, 6 Jan 2023 13:41:19 +0800 Subject: [PATCH] [gemini] add get static torch model (#2356) --- colossalai/nn/parallel/data_parallel.py | 14 +- colossalai/nn/parallel/utils.py | 93 +++++++++-- .../dreambooth/train_dreambooth_colossalai.py | 148 ++++++++---------- ...orch_module.py => test_get_torch_model.py} | 27 ++-- 4 files changed, 164 insertions(+), 118 deletions(-) rename tests/test_gemini/update/{test_convert_torch_module.py => test_get_torch_model.py} (60%) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index cbef6f532..e3bb83347 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -389,19 +389,6 @@ class ZeroDDP(ColoDDP): del temp_chunk return param_to_save_data - def torch_named_parameters(self): - """ - get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload. - It works the same as torch.Module named_parameters - """ - params_list = [p for p in self.parameters(recurse=True)] - param_to_save_data = self._get_param_to_save_data(params_list, False) - for (name, _), p in zip(self.named_parameters(recurse=True), params_list): - if p is not None: - assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[p] - yield name, record_parameter - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every @@ -418,6 +405,7 @@ class ZeroDDP(ColoDDP): assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) + # TODO: (HELSON) deal with ddp ignored parameters for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py index e514146ce..1205cbc3a 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/nn/parallel/utils.py @@ -1,5 +1,10 @@ +from collections import OrderedDict +from copy import copy +from typing import Optional, Set + import torch import torch.distributed as dist +import torch.nn as nn from colossalai.gemini.chunk import Chunk from colossalai.utils import get_current_device @@ -21,30 +26,88 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk): return total_temp -# TODO() not work for module where two params share the same tensor. -def _add_param(model, name, param): - name_list = name.split('.') - module = model._modules[name_list[0]] - for i in range(1, len(name_list) - 1): - module = module._modules[name_list[i]] - module._parameters[name_list[-1]] = param +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules. + """ + if memo is None: + memo = set() + if module not in memo: + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ('.' 
if prefix else '') + name + for m in _get_dfs_module_list(submodule, memo, submodule_prefix): + yield m + + memo.add(module) + yield prefix, module -def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module: - """convert_to_torch_module +def _get_shallow_copy_model(model: nn.Module): + """Get a shallow copy of the given model. Each submodule is different from the original submodule. + But the new submodule and the old submodule share all attributes. + """ + name_to_module = dict() + for name, module in _get_dfs_module_list(model): + new_module = copy(module) + new_module._modules = OrderedDict() + for subname, submodule in module._modules.items(): + if submodule is None: + continue + full_name = name + ('.' if name else '') + subname + setattr(new_module, subname, name_to_module[full_name]) + name_to_module[name] = new_module + return name_to_module[''] + + +def get_static_torch_model(gemini_ddp_model, + device=torch.device("cpu"), + dtype=torch.float32, + only_rank_0=True) -> torch.nn.Module: + """Get a static torch.nn.Module model from the given GeminiDDP module. + You should notice that the original GeminiDDP model is not modified. + Thus, you can use the original model in further training. + But you should not use the returned torch model to train, this can cause unexpected errors. Args: gemini_ddp_model (GeminiDDP): a gemini ddp model + device (torch.device): the device of the final torch model + dtype (torch.dtype): the dtype of the final torch model + only_rank_0 (bool): if True, only rank0 has the coverted torch model Returns: - torch.nn.Module: a torch model contains the params of gemini_ddp_model + torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ from colossalai.nn.parallel import GeminiDDP assert isinstance(gemini_ddp_model, GeminiDDP) - module = gemini_ddp_model.module - # replace ColoTensor to torch.nn.Tensor in module - for n, p in gemini_ddp_model.torch_named_parameters(): - _add_param(module, n, p) + state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0) + colo_model = gemini_ddp_model.module + torch_model = _get_shallow_copy_model(colo_model) - return module + if not only_rank_0 or dist.get_rank() == 0: + # record the mapping relationship between colo parameters and torch parameters + colo_to_torch = dict() + for (name, colo_module), (_, torch_module) in \ + zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)): + # clean the parameter list of the new torch module + torch_module._parameters = OrderedDict() + for sufix_param_name, param in colo_module.named_parameters(recurse=False): + # get the full name of the parameter + full_param_name = name + ('.' 
if name else '') + sufix_param_name + + if full_param_name not in state_dict: + # this means the parameter is shared by multiple modules + # we should use colo_to_torch to get the torch parameter created before + assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module" + torch_param = colo_to_torch[param] + else: + # we meet the parameter the first time, just use the state dict to get the data + state_param = state_dict[full_param_name] + torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype)) + colo_to_torch[param] = torch_param + + setattr(torch_module, sufix_param_name, torch_param) + dist.barrier() + + return torch_model diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index b95353d9b..b7e24bfe4 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -8,25 +8,23 @@ from typing import Optional import torch import torch.nn.functional as F import torch.utils.checkpoint +from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel.utils import convert_to_torch_module -from colossalai.tensor import ProcessGroup +from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel -from diffusers.optimization import get_scheduler -from huggingface_hub import HfFolder, Repository, whoami -from PIL import Image -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, PretrainedConfig - disable_existing_loggers() logger = get_dist_logger() @@ -112,10 +110,8 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=( - "Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt." - ), + help=("Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt."), ) parser.add_argument( "--output_dir", @@ -128,10 +124,8 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=( - "The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution" - ), + help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution"), ) parser.add_argument( "--placement", @@ -139,15 +133,14 @@ def parse_args(input_args=None): default="cpu", help="Placement Policy for Gemini. 
Valid when using colossalai as dist plan.", ) - parser.add_argument( - "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" - ) - parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument( - "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." - ) + parser.add_argument("--center_crop", + action="store_true", + help="Whether to center crop images before resizing to resolution") + parser.add_argument("--train_batch_size", + type=int, + default=4, + help="Batch size (per device) for the training dataloader.") + parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--max_train_steps", @@ -183,17 +176,16 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]'), ) + parser.add_argument("--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument("--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -208,10 +200,8 @@ def parse_args(input_args=None): "--logging_dir", type=str, default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), + help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), ) parser.add_argument( "--mixed_precision", @@ -221,8 +211,7 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), + " flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config."), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -288,14 +277,12 @@ class DreamBoothDataset(Dataset): else: self.class_data_root = None - self.image_transforms = transforms.Compose( - [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + self.image_transforms = transforms.Compose([ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ]) def __len__(self): return self._length @@ -356,26 +343,19 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: # Gemini + ZeRO DDP -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"): +def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP( - model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32 - ) + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=placememt_policy, + pin_memory=True, + search_range_mb=64) return model def main(args): - # config for colossalai - - config = { - "BATCH": args.train_batch_size, - "gradient_accumulation_steps": args.gradient_accumulation_steps, - "clip_grad_norm": args.max_grad_norm, - } - - colossalai.launch_from_torch(config=config) - pg = ProcessGroup() + colossalai.launch_from_torch(config={}) if args.seed is not None: gpc.set_seed(args.seed) @@ -405,9 +385,9 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not gpc.get_local_rank(ParallelMode.DATA) == 0, + sample_dataloader, + desc="Generating class images", + disable=not gpc.get_local_rank(ParallelMode.DATA) == 0, ): images = pipeline(example["prompt"]).images @@ -472,10 +452,11 @@ def main(args): ) logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - with ColoInitContext(): - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False - ) + with ColoInitContext(device=get_current_device()): + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + low_cpu_mem_usage=False) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -486,10 +467,10 @@ def main(args): if args.scale_lr: args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) - unet = gemini_zero_dpp(unet, pg, args.placement) + unet = gemini_zero_dpp(unet, args.placement) # config optimizer for colossalai zero - optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5) + optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") @@ -520,7 +501,9 @@ def main(args): pixel_values = 
pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - {"input_ids": input_ids}, + { + "input_ids": input_ids + }, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -532,9 +515,11 @@ def main(args): } return batch - train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 - ) + train_dataloader = torch.utils.data.DataLoader(train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=1) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -652,15 +637,16 @@ def main(args): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: torch.cuda.synchronize() + torch_unet = get_static_torch_model(unet) if gpc.get_local_rank(ParallelMode.DATA) == 0: pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=convert_to_torch_module(unet), + unet=torch_unet, revision=args.revision, ) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") @@ -670,7 +656,7 @@ def main(args): break torch.cuda.synchronize() - unet = convert_to_torch_module(unet) + unet = get_static_torch_model(unet) if gpc.get_local_rank(ParallelMode.DATA) == 0: pipeline = DiffusionPipeline.from_pretrained( diff --git a/tests/test_gemini/update/test_convert_torch_module.py b/tests/test_gemini/update/test_get_torch_model.py similarity index 60% rename from tests/test_gemini/update/test_convert_torch_module.py rename to tests/test_gemini/update/test_get_torch_model.py index 160099167..e6d586b37 100644 --- a/tests/test_gemini/update/test_convert_torch_module.py +++ b/tests/test_gemini/update/test_get_torch_model.py @@ -6,8 +6,9 @@ import torch import torch.multiprocessing as mp import colossalai -from colossalai.nn.parallel.utils import convert_to_torch_module -from colossalai.tensor import ColoTensor +from colossalai.nn.parallel import GeminiDDP +from colossalai.nn.parallel.utils import get_static_torch_model +from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device @@ -15,21 +16,29 @@ from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs -@parameterize('model_name', ['resnet18', 'bert']) +@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2']) def run_convert_torch_module(model_name: str): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, _, _, _, _ = get_components_func() - with ColoInitContext(device='cpu'): + with ColoInitContext(device=torch.device("cpu")): model = model_builder(checkpoint=False) - - from colossalai.nn.parallel import GeminiDDP model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True) - - pytorch_model = convert_to_torch_module(model) + pytorch_model = get_static_torch_model(model, only_rank_0=False) for n, p in pytorch_model.named_parameters(): - assert not isinstance(p, ColoTensor) + assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}" + + # get the static model should not change the original model + for 
n, p in model.named_parameters(): + assert isinstance(p, ColoParameter) + + for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()): + assert pn == cn + assert id(pm) != id(cm) + for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)): + assert id(pp) != id(cp) + assert pp.shape == cp.shape def run_dist(rank, world_size, port):
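
For context, this is roughly how the new `get_static_torch_model` helper is meant to be used outside of the DreamBooth script. The sketch below is a minimal example and not part of the patch: the toy `Net` module, the `save_static_checkpoint` helper, and the checkpoint path are hypothetical, and it assumes a standard torchrun launch; the colossalai calls themselves mirror the ones that already appear in this patch.

import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


class Net(nn.Module):    # hypothetical toy model, standing in for UNet2DConditionModel etc.

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)


def save_static_checkpoint(path: str):    # hypothetical helper, not part of the patch
    colossalai.launch_from_torch(config={})

    # build the model under ColoInitContext so its parameters become ColoParameters
    with ColoInitContext(device=get_current_device()):
        model = Net()
    model = GeminiDDP(model, device=get_current_device(), placement_policy="auto", pin_memory=True)

    # ... training with GeminiAdamOptimizer would happen here ...

    # gather the full parameters on rank 0 and rebuild a plain, static torch.nn.Module
    torch_model = get_static_torch_model(model, dtype=torch.float32, only_rank_0=True)
    if dist.get_rank() == 0:
        torch.save(torch_model.state_dict(), path)

The returned module is only meant for saving checkpoints or numeric checks; training should continue on the original GeminiDDP model.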
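The `_get_shallow_copy_model` helper relies on the fact that `copy.copy` of an `nn.Module` produces a new module object whose `__dict__` still points at the original attributes, so rebinding `_modules` and `_parameters` on the copy leaves the original model untouched. A small plain-PyTorch sketch of that idea, with a hypothetical replacement parameter standing in for the one built from the gathered state dict:

from collections import OrderedDict
from copy import copy

import torch.nn as nn

layer = nn.Linear(4, 4)
clone = copy(layer)                  # new module object, attributes still shared with `layer`
clone._parameters = OrderedDict()    # rebind the parameter table of the copy only

# hypothetical replacement parameter (in the patch this comes from the gathered state dict)
clone.weight = nn.Parameter(layer.weight.detach().clone())

assert clone is not layer                        # distinct module objects
assert clone.in_features == layer.in_features    # plain attributes are shared values
assert clone.weight is not layer.weight          # the copy got a new parameter ...
assert 'weight' in layer._parameters             # ... while the original module is untouched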
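The `colo_to_torch` dictionary in `get_static_torch_model` handles parameters shared by several modules: `named_parameters()` (and therefore the gathered state dict) reports such a parameter only once, while the per-module walk with `recurse=False` meets it once per owning module. A plain-PyTorch illustration with a hypothetical tied-weight module:

import torch.nn as nn


class TiedNet(nn.Module):    # hypothetical module with a weight shared by two submodules

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight    # one Parameter registered under two names


model = TiedNet()

# named_parameters() de-duplicates, so the gathered state dict holds the weight only once ...
print([name for name, _ in model.named_parameters()])    # ['embed.weight']

# ... but the per-module walk used by get_static_torch_model visits it once per owner:
owners = [f"{mod_name}.{p_name}"
          for mod_name, mod in model.named_modules()
          for p_name, _ in mod.named_parameters(recurse=False)]
print(owners)    # ['embed.weight', 'head.weight']

When the walk reaches `head.weight` and finds no matching key in the state dict, it reuses the Parameter already created for `embed.weight`, which is what keeps the weight tie intact in the static model.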