mirror of https://github.com/hpcaitech/ColossalAI
[gemini] add get static torch model (#2356)
parent 7a332b1734
commit 48d33b1b17
@@ -389,19 +389,6 @@ class ZeroDDP(ColoDDP):
         del temp_chunk
         return param_to_save_data
 
-    def torch_named_parameters(self):
-        """
-        get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
-        It works the same as torch.Module named_parameters
-        """
-        params_list = [p for p in self.parameters(recurse=True)]
-        param_to_save_data = self._get_param_to_save_data(params_list, False)
-        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
-            if p is not None:
-                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
-                record_parameter = param_to_save_data[p]
-                yield name, record_parameter
-
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every

@@ -418,6 +405,7 @@ class ZeroDDP(ColoDDP):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+        # TODO: (HELSON) deal with ddp ignored parameters
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
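With torch_named_parameters() removed from ZeroDDP above, callers that used to iterate over the real parameter payloads should instead go through state_dict() or the get_static_torch_model() utility introduced below. A minimal sketch of the replacement path, assuming `model` is a ZeroDDP/GeminiDDP instance (the surrounding setup is illustrative, not part of this commit):

    from colossalai.nn.parallel.utils import get_static_torch_model

    # gather the full fp32 values through the state dict; with only_rank_0=True
    # the gathered values are only guaranteed to be complete on rank 0
    for name, tensor in model.state_dict(only_rank_0=True).items():
        print(name, tuple(tensor.shape))

    # or materialize a plain torch.nn.Module copy and use the usual PyTorch API
    torch_model = get_static_torch_model(model, only_rank_0=True)
    for name, param in torch_model.named_parameters():
        print(name, param.shape)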
@@ -1,5 +1,10 @@
+from collections import OrderedDict
+from copy import copy
+from typing import Optional, Set
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 
 from colossalai.gemini.chunk import Chunk
 from colossalai.utils import get_current_device
@@ -21,30 +26,88 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp
 
 
-# TODO() not work for module where two params share the same tensor.
-def _add_param(model, name, param):
-    name_list = name.split('.')
-    module = model._modules[name_list[0]]
-    for i in range(1, len(name_list) - 1):
-        module = module._modules[name_list[i]]
-    module._parameters[name_list[-1]] = param
+def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
+    """Get a dfs module list of the given module. Its order is same as the order of creations of modules.
+    """
+    if memo is None:
+        memo = set()
+    if module not in memo:
+        for name, submodule in module._modules.items():
+            if submodule is None:
+                continue
+            submodule_prefix = prefix + ('.' if prefix else '') + name
+            for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
+                yield m
+
+        memo.add(module)
+        yield prefix, module
 
 
-def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
-    """convert_to_torch_module
+def _get_shallow_copy_model(model: nn.Module):
+    """Get a shallow copy of the given model. Each submodule is different from the original submodule.
+    But the new submodule and the old submodule share all attributes.
+    """
+    name_to_module = dict()
+    for name, module in _get_dfs_module_list(model):
+        new_module = copy(module)
+        new_module._modules = OrderedDict()
+        for subname, submodule in module._modules.items():
+            if submodule is None:
+                continue
+            full_name = name + ('.' if name else '') + subname
+            setattr(new_module, subname, name_to_module[full_name])
+        name_to_module[name] = new_module
+    return name_to_module['']
+
+
+def get_static_torch_model(gemini_ddp_model,
+                           device=torch.device("cpu"),
+                           dtype=torch.float32,
+                           only_rank_0=True) -> torch.nn.Module:
+    """Get a static torch.nn.Module model from the given GeminiDDP module.
+    You should notice that the original GeminiDDP model is not modified.
+    Thus, you can use the original model in further training.
+    But you should not use the returned torch model to train, this can cause unexpected errors.
 
     Args:
         gemini_ddp_model (GeminiDDP): a gemini ddp model
+        device (torch.device): the device of the final torch model
+        dtype (torch.dtype): the dtype of the final torch model
+        only_rank_0 (bool): if True, only rank0 has the converted torch model
 
     Returns:
-        torch.nn.Module: a torch model contains the params of gemini_ddp_model
+        torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
     from colossalai.nn.parallel import GeminiDDP
     assert isinstance(gemini_ddp_model, GeminiDDP)
-    module = gemini_ddp_model.module
 
-    # replace ColoTensor to torch.nn.Tensor in module
-    for n, p in gemini_ddp_model.torch_named_parameters():
-        _add_param(module, n, p)
+    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
+    colo_model = gemini_ddp_model.module
+    torch_model = _get_shallow_copy_model(colo_model)
 
-    return module
+    if not only_rank_0 or dist.get_rank() == 0:
+        # record the mapping relationship between colo parameters and torch parameters
+        colo_to_torch = dict()
+        for (name, colo_module), (_, torch_module) in \
+                zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
+            # clean the parameter list of the new torch module
+            torch_module._parameters = OrderedDict()
+            for sufix_param_name, param in colo_module.named_parameters(recurse=False):
+                # get the full name of the parameter
+                full_param_name = name + ('.' if name else '') + sufix_param_name
+
+                if full_param_name not in state_dict:
+                    # this means the parameter is shared by multiple modules
+                    # we should use colo_to_torch to get the torch parameter created before
+                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
+                    torch_param = colo_to_torch[param]
+                else:
+                    # we meet the parameter the first time, just use the state dict to get the data
+                    state_param = state_dict[full_param_name]
+                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
+                    colo_to_torch[param] = torch_param
+
+                setattr(torch_module, sufix_param_name, torch_param)
+    dist.barrier()
+
+    return torch_model
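For context, a minimal usage sketch of the new utility (the model construction below is illustrative, not part of this commit): the GeminiDDP instance stays usable for further training, while the returned plain module can be saved or handed to code that expects an ordinary torch.nn.Module.

    import torch
    import torch.distributed as dist
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.nn.parallel.utils import get_static_torch_model
    from colossalai.utils import get_current_device

    # assume `net` was built under ColoInitContext and then wrapped by Gemini
    model = GeminiDDP(net, device=get_current_device(), placement_policy='auto', pin_memory=True)

    # ... train with `model` as usual ...

    # every rank should call this: gathering the chunks and the trailing
    # dist.barrier() are collective; with only_rank_0=True only rank 0
    # receives the fully materialized copy
    torch_model = get_static_torch_model(model, dtype=torch.float32, only_rank_0=True)
    if dist.get_rank() == 0:
        torch.save(torch_model.state_dict(), "model_checkpoint.pt")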
@@ -8,25 +8,23 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, whoami
+from PIL import Image
 from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
 
 import colossalai
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
-from PIL import Image
-from torchvision import transforms
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
-
 
 disable_existing_loggers()
 logger = get_dist_logger()
@@ -112,10 +110,8 @@ def parse_args(input_args=None):
         "--num_class_images",
         type=int,
         default=100,
-        help=(
-            "Minimal class images for prior preservation loss. If there are not enough images already present in"
-            " class_data_dir, additional images will be sampled with class_prompt."
-        ),
+        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
+              " class_data_dir, additional images will be sampled with class_prompt."),
     )
     parser.add_argument(
         "--output_dir",
@@ -128,10 +124,8 @@ def parse_args(input_args=None):
         "--resolution",
         type=int,
         default=512,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
+        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
+              " resolution"),
     )
     parser.add_argument(
         "--placement",
@@ -139,15 +133,14 @@ def parse_args(input_args=None):
         default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument(
-        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
-        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument(
-        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
-    )
+    parser.add_argument("--center_crop",
+                        action="store_true",
+                        help="Whether to center crop images before resizing to resolution")
+    parser.add_argument("--train_batch_size",
+                        type=int,
+                        default=4,
+                        help="Batch size (per device) for the training dataloader.")
+    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
     parser.add_argument("--num_train_epochs", type=int, default=1)
     parser.add_argument(
         "--max_train_steps",
@@ -183,17 +176,16 @@ def parse_args(input_args=None):
         "--lr_scheduler",
         type=str,
         default="constant",
-        help=(
-            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
-        ),
+        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+              ' "constant", "constant_with_warmup"]'),
     )
-    parser.add_argument(
-        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
-    )
-    parser.add_argument(
-        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
-    )
+    parser.add_argument("--lr_warmup_steps",
+                        type=int,
+                        default=500,
+                        help="Number of steps for the warmup in the lr scheduler.")
+    parser.add_argument("--use_8bit_adam",
+                        action="store_true",
+                        help="Whether or not to use 8-bit Adam from bitsandbytes.")
 
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
@@ -208,10 +200,8 @@ def parse_args(input_args=None):
         "--logging_dir",
         type=str,
         default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
+        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
     )
     parser.add_argument(
         "--mixed_precision",
@@ -221,8 +211,7 @@ def parse_args(input_args=None):
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
             " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
-            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
-        ),
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
@@ -288,14 +277,12 @@ class DreamBoothDataset(Dataset):
         else:
             self.class_data_root = None
 
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        self.image_transforms = transforms.Compose([
+            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ])
 
     def __len__(self):
         return self._length
@@ -356,26 +343,19 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
-    model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
-    )
+    model = GeminiDDP(model,
+                      device=get_current_device(),
+                      placement_policy=placememt_policy,
+                      pin_memory=True,
+                      search_range_mb=64)
     return model
 
 
 def main(args):
-    # config for colossalai
-
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})
 
     if args.seed is not None:
         gpc.set_seed(args.seed)
@@ -472,10 +452,11 @@ def main(args):
     )
 
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
-        unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
-        )
+    with ColoInitContext(device=get_current_device()):
+        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+                                                    subfolder="unet",
+                                                    revision=args.revision,
+                                                    low_cpu_mem_usage=False)
 
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -486,10 +467,10 @@ def main(args):
     if args.scale_lr:
         args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
 
-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
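Because gradient clipping now rides on the optimizer (clipping_norm above) instead of a launch-config clip_grad_norm, the per-step logic does not clip manually. A rough sketch of the step the script is assumed to run (model_pred, target and lr_scheduler are placeholders from the surrounding script, not part of this hunk); with ZeRO/Gemini the backward pass must go through the optimizer:

    optimizer.zero_grad()
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    optimizer.backward(loss)    # ZeRO requires optimizer.backward(loss) instead of loss.backward()
    optimizer.step()
    lr_scheduler.step()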
@@ -520,7 +501,9 @@ def main(args):
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad(
-            {"input_ids": input_ids},
+            {
+                "input_ids": input_ids
+            },
             padding="max_length",
             max_length=tokenizer.model_max_length,
             return_tensors="pt",
@@ -532,9 +515,11 @@ def main(args):
         }
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
-    )
+    train_dataloader = torch.utils.data.DataLoader(train_dataset,
+                                                   batch_size=args.train_batch_size,
+                                                   shuffle=True,
+                                                   collate_fn=collate_fn,
+                                                   num_workers=1)
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -657,10 +642,11 @@ def main(args):
 
         if global_step % args.save_steps == 0:
             torch.cuda.synchronize()
+            torch_unet = get_static_torch_model(unet)
             if gpc.get_local_rank(ParallelMode.DATA) == 0:
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=convert_to_torch_module(unet),
+                    unet=torch_unet,
                     revision=args.revision,
                 )
                 save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
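Note the ordering above: get_static_torch_model(unet) now runs on every data-parallel rank before the rank check, because gathering the scattered chunks (and the barrier inside the utility) is collective, while building and saving the pipeline stays rank-0 only. A condensed sketch of the assumed checkpoint path (the save_pretrained call mirrors the rest of the script and is not shown in this hunk):

    torch_unet = get_static_torch_model(unet)    # all ranks participate in the gather
    if gpc.get_local_rank(ParallelMode.DATA) == 0:
        pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path,
                                                     unet=torch_unet,
                                                     revision=args.revision)
        pipeline.save_pretrained(save_path)    # only one rank writes the checkpoint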
@@ -670,7 +656,7 @@ def main(args):
             break
 
     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)
 
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
@@ -6,8 +6,9 @@ import torch
 import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ColoTensor
+from colossalai.nn.parallel import GeminiDDP
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.tensor import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
@@ -15,21 +16,29 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 
-@parameterize('model_name', ['resnet18', 'bert'])
+@parameterize('model_name', ['hanging_param_model', 'resnet18', 'gpt2'])
 def run_convert_torch_module(model_name: str):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, _, _, _, _ = get_components_func()
 
-    with ColoInitContext(device='cpu'):
+    with ColoInitContext(device=torch.device("cpu")):
         model = model_builder(checkpoint=False)
 
-    from colossalai.nn.parallel import GeminiDDP
     model = GeminiDDP(model, device=get_current_device(), placement_policy='auto', pin_memory=True)
-
-    pytorch_model = convert_to_torch_module(model)
+    pytorch_model = get_static_torch_model(model, only_rank_0=False)
 
     for n, p in pytorch_model.named_parameters():
-        assert not isinstance(p, ColoTensor)
+        assert type(p) == torch.nn.Parameter, f"type error: {n} is a {type(p)}"
+
+    # get the static model should not change the original model
+    for n, p in model.named_parameters():
+        assert isinstance(p, ColoParameter)
+
+    for (pn, pm), (cn, cm) in zip(pytorch_model.named_modules(), model.named_modules()):
+        assert pn == cn
+        assert id(pm) != id(cm)
+        for pp, cp in zip(pm.parameters(recurse=False), cm.parameters(recurse=False)):
+            assert id(pp) != id(cp)
+            assert pp.shape == cp.shape
 
 
 def run_dist(rank, world_size, port):