[gemini] add get static torch model (#2356)

pull/2365/head^2
HELSON 2023-01-06 13:41:19 +08:00 committed by GitHub
parent 7a332b1734
commit 48d33b1b17
4 changed files with 164 additions and 118 deletions

View File

@@ -389,19 +389,6 @@ class ZeroDDP(ColoDDP):
         del temp_chunk
         return param_to_save_data
 
-    def torch_named_parameters(self):
-        """
-        get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
-        It works the same as torch.Module named_parameters
-        """
-        params_list = [p for p in self.parameters(recurse=True)]
-        param_to_save_data = self._get_param_to_save_data(params_list, False)
-        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
-            if p is not None:
-                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
-                record_parameter = param_to_save_data[p]
-                yield name, record_parameter
-
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
@@ -418,6 +405,7 @@ class ZeroDDP(ColoDDP):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+        # TODO: (HELSON) deal with ddp ignored parameters
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
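Note: with `torch_named_parameters()` removed, the chunk-aware `state_dict()` (or the new `get_static_torch_model` helper added in the next file) is the supported way to read back full parameter payloads. A minimal sketch, assuming an already-constructed `ZeroDDP`/`GeminiDDP` instance named `zero_model` inside an initialized distributed run:

    import torch.distributed as dist

    # Hedged sketch: gather the real fp32 payloads through the chunk-aware state_dict()
    state_dict = zero_model.state_dict(only_rank_0=True)    # `zero_model` is an assumed ZeroDDP/GeminiDDP instance
    if dist.get_rank() == 0:
        for name, tensor in state_dict.items():
            print(name, tuple(tensor.shape), tensor.dtype)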

View File

@@ -1,5 +1,10 @@
+from collections import OrderedDict
+from copy import copy
+from typing import Optional, Set
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 
 from colossalai.gemini.chunk import Chunk
 from colossalai.utils import get_current_device
@@ -21,30 +26,88 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp
 
 
-# TODO() not work for module where two params share the same tensor.
-def _add_param(model, name, param):
-    name_list = name.split('.')
-    module = model._modules[name_list[0]]
-    for i in range(1, len(name_list) - 1):
-        module = module._modules[name_list[i]]
-    module._parameters[name_list[-1]] = param
+def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
+    """Get a dfs module list of the given module. Its order is the same as the order of creation of the modules.
+    """
+    if memo is None:
+        memo = set()
+    if module not in memo:
+        for name, submodule in module._modules.items():
+            if submodule is None:
+                continue
+            submodule_prefix = prefix + ('.' if prefix else '') + name
+            for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
+                yield m
+
+        memo.add(module)
+        yield prefix, module
 
 
-def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
-    """convert_to_torch_module
+def _get_shallow_copy_model(model: nn.Module):
+    """Get a shallow copy of the given model. Each submodule is different from the original submodule.
+    But the new submodule and the old submodule share all attributes.
+    """
+    name_to_module = dict()
+    for name, module in _get_dfs_module_list(model):
+        new_module = copy(module)
+        new_module._modules = OrderedDict()
+        for subname, submodule in module._modules.items():
+            if submodule is None:
+                continue
+            full_name = name + ('.' if name else '') + subname
+            setattr(new_module, subname, name_to_module[full_name])
+        name_to_module[name] = new_module
+    return name_to_module['']
+
+
+def get_static_torch_model(gemini_ddp_model,
+                           device=torch.device("cpu"),
+                           dtype=torch.float32,
+                           only_rank_0=True) -> torch.nn.Module:
+    """Get a static torch.nn.Module model from the given GeminiDDP module.
+    You should notice that the original GeminiDDP model is not modified.
+    Thus, you can use the original model in further training.
+    But you should not use the returned torch model to train; this can cause unexpected errors.
 
     Args:
         gemini_ddp_model (GeminiDDP): a gemini ddp model
+        device (torch.device): the device of the final torch model
+        dtype (torch.dtype): the dtype of the final torch model
+        only_rank_0 (bool): if True, only rank 0 has the converted torch model
 
     Returns:
-        torch.nn.Module: a torch model contains the params of gemini_ddp_model
+        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
     from colossalai.nn.parallel import GeminiDDP
     assert isinstance(gemini_ddp_model, GeminiDDP)
-    module = gemini_ddp_model.module
-    # replace ColoTensor to torch.nn.Tensor in module
-    for n, p in gemini_ddp_model.torch_named_parameters():
-        _add_param(module, n, p)
-    return module
+
+    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
+    colo_model = gemini_ddp_model.module
+    torch_model = _get_shallow_copy_model(colo_model)
+
+    if not only_rank_0 or dist.get_rank() == 0:
+        # record the mapping relationship between colo parameters and torch parameters
+        colo_to_torch = dict()
+        for (name, colo_module), (_, torch_module) in \
+                zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
+            # clean the parameter list of the new torch module
+            torch_module._parameters = OrderedDict()
+            for sufix_param_name, param in colo_module.named_parameters(recurse=False):
+                # get the full name of the parameter
+                full_param_name = name + ('.' if name else '') + sufix_param_name
+                if full_param_name not in state_dict:
+                    # this means the parameter is shared by multiple modules
+                    # we should use colo_to_torch to get the torch parameter created before
+                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
+                    torch_param = colo_to_torch[param]
+                else:
+                    # we meet the parameter the first time, just use the state dict to get the data
+                    state_param = state_dict[full_param_name]
+                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
+                    colo_to_torch[param] = torch_param
+                setattr(torch_module, sufix_param_name, torch_param)
+
+    dist.barrier()
+    return torch_model
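
For reference, a minimal usage sketch of the new helper. `model_builder` and the file name are placeholders, and the GeminiDDP arguments mirror the test further below:

    import torch
    import torch.distributed as dist
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.nn.parallel.utils import get_static_torch_model
    from colossalai.utils import get_current_device

    # wrap a ColoTensor model (built under ColoInitContext) with Gemini/ZeRO
    gemini_model = GeminiDDP(model_builder(), device=get_current_device(),
                             placement_policy='auto', pin_memory=True)
    # ... train gemini_model as usual ...

    # materialize a plain torch.nn.Module copy; the Gemini model is left untouched
    torch_model = get_static_torch_model(gemini_model, device=torch.device("cpu"),
                                         dtype=torch.float32, only_rank_0=True)
    if dist.get_rank() == 0:
        torch.save(torch_model.state_dict(), "checkpoint.pt")    # for checkpoints/numeric checks, not training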

View File

@@ -8,25 +8,23 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
 from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from PIL import Image
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -112,10 +110,8 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
-        help=(
-            "Minimal class images for prior preservation loss. If there are not enough images already present in"
-            " class_data_dir, additional images will be sampled with class_prompt."
-        ),
+        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
+              " class_data_dir, additional images will be sampled with class_prompt."),
     )
     parser.add_argument(
         "--output_dir",
@@ -128,10 +124,8 @@ def parse_args(input_args=None):
         "--resolution",
         type=int,
         default=512,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
+        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
+              " resolution"),
     )
     parser.add_argument(
         "--placement",
@@ -139,15 +133,14 @@ def parse_args(input_args=None):
         default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument(
-        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
-        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
-        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
-    )
+    parser.add_argument("--center_crop",
+                        action="store_true",
+                        help="Whether to center crop images before resizing to resolution")
+    parser.add_argument("--train_batch_size",
+                        type=int,
+                        default=4,
+                        help="Batch size (per device) for the training dataloader.")
+    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",
@@ -183,17 +176,16 @@ def parse_args(input_args=None):
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=(
-            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
-        ),
-    )
-    parser.add_argument(
-        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument(
-        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+              ' "constant", "constant_with_warmup"]'),
     )
+    parser.add_argument("--lr_warmup_steps",
+                        type=int,
+                        default=500,
+                        help="Number of steps for the warmup in the lr scheduler.")
+    parser.add_argument("--use_8bit_adam",
+                        action="store_true",
+                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,10 +200,8 @@ def parse_args(input_args=None):
         "--logging_dir",
         type=str,
         default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
+        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
     )
     parser.add_argument(
         "--mixed_precision",
@ -221,8 +211,7 @@ def parse_args(input_args=None):
help=( help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
),
) )
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -288,14 +277,12 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose(
-            [
+        self.image_transforms = transforms.Compose([
             transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
             transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        ])
 
     def __len__(self):
         return self._length
@@ -356,26 +343,19 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
-    model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
-    )
+    model = GeminiDDP(model,
+                      device=get_current_device(),
+                      placement_policy=placememt_policy,
+                      pin_memory=True,
+                      search_range_mb=64)
     return model
 
 
 def main(args):
-    # config for colossalai
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})
 
     if args.seed is not None:
         gpc.set_seed(args.seed)
@@ -472,10 +452,11 @@ def main(args):
     )
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
-        )
+    with ColoInitContext(device=get_current_device()):
+        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+                                                    subfolder="unet",
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -486,10 +467,10 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
 
-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
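
Since the launch config no longer carries `clip_grad_norm`, gradient clipping is configured directly on the ZeRO optimizer through `clipping_norm`. A short sketch of the resulting training-step pattern; the `loss` value is assumed to come from the usual forward pass in the training loop:

    # Hedged sketch: clipping is owned by the Gemini/ZeRO optimizer via `clipping_norm`
    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5,
                                    clipping_norm=args.max_grad_norm)

    optimizer.zero_grad()
    optimizer.backward(loss)    # ZeRO-aware backward; `loss` assumed computed as in the training loop
    optimizer.step()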
@@ -520,7 +501,9 @@ def main(args):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad(
-            {"input_ids": input_ids},
+            {
+                "input_ids": input_ids
+            },
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
@@ -532,9 +515,11 @@ def main(args):
         }
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
-    )
+    train_dataloader = torch.utils.data.DataLoader(train_dataset,
+                                                   batch_size=args.train_batch_size,
+                                                   shuffle=True,
+                                                   collate_fn=collate_fn,
+                                                   num_workers=1)
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -657,10 +642,11 @@ def main(args):
                 if global_step % args.save_steps == 0:
                     torch.cuda.synchronize()
+                    torch_unet = get_static_torch_model(unet)
                     if gpc.get_local_rank(ParallelMode.DATA) == 0:
                         pipeline = DiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=convert_to_torch_module(unet),
+                            unet=torch_unet,
                             revision=args.revision,
                         )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
@@ -670,7 +656,7 @@ def main(args):
                 break
 
     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)
 
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
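
The end-of-training export follows the same pattern as the periodic checkpoint above: convert the Gemini-wrapped `unet` to a static torch module once, then let data-parallel rank 0 assemble and save the pipeline. A condensed sketch (variable names follow the script; `save_path` stands in for the chosen output directory, e.g. `args.output_dir`):

    torch.cuda.synchronize()
    torch_unet = get_static_torch_model(unet)            # plain torch.nn.Module, weights on CPU by default
    if gpc.get_local_rank(ParallelMode.DATA) == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            unet=torch_unet,
            revision=args.revision,
        )
        pipeline.save_pretrained(save_path)              # `save_path` is an assumed destination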

View File

@@ -6,8 +6,9 @@ import torch
 import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.tensor import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
@@ -15,21 +16,29 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 
-@parameterize('model_name', ['resnet18', 'bert'])
+@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
 def run_convert_torch_module(model_name: str):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, _, _, _, _ = get_components_func()
 
-    with ColoInitContext(device='cpu'):
+    with ColoInitContext(device=torch.device("cpu")):
         model = model_builder(checkpoint=False)
-
-    from colossalai.nn.parallel import GeminiDDP
     model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
-
-    pytorch_model = convert_to_torch_module(model)
+    pytorch_model = get_static_torch_model(model, only_rank_0=False)
 
     for n, p in pytorch_model.named_parameters():
-        assert not isinstance(p, ColoTensor)
+        assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}"
+
+    # getting the static model should not change the original model
+    for n, p in model.named_parameters():
+        assert isinstance(p, ColoParameter)
+    for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()):
+        assert pn == cn
+        assert id(pm) != id(cm)
+        for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)):
+            assert id(pp) != id(cp)
+            assert pp.shape == cp.shape
 
 
 def run_dist(rank, world_size, port):
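
The diff stops at the spawn helper. For context, ColossalAI unit tests of this shape are typically driven by a `torch.multiprocessing` wrapper roughly like the sketch below; this is a generic illustration, not the actual body of `run_dist` in the commit:

    from functools import partial

    import torch.multiprocessing as mp

    import colossalai
    from colossalai.utils import free_port

    def run_dist(rank, world_size, port):
        # hypothetical launcher body: bring up the process group, then run the check on every rank
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        run_convert_torch_module()

    def test_convert_torch_module(world_size=2):
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)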