[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)

* hybrid plugin supports HuggingFace from_pretrained (usage sketch below)

* add huggingface compatibility tests

* add folder cleaning

* fix bugs
commit 38ccb8b1a3 (parent c9625dbb63)
Author: Baizhou Zhang
Date: 2023-09-01 17:40:01 +08:00 (committed via GitHub)
5 changed files with 218 additions and 17 deletions
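
The practical effect of this change, as a rough usage sketch (model name, plugin arguments, and paths are illustrative, and distributed initialization with four processes is assumed): a sharded save through the booster now also writes the HuggingFace config files, so the checkpoint directory can be reloaded with plain from_pretrained.

# Illustrative sketch only -- names and arguments are placeholders, not part of this diff.
import torch
import torch.distributed as dist
import colossalai
from transformers import GPT2LMHeadModel
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})    # assumes torchrun with 4 processes
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='fp16', initial_scale=1)
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel.from_pretrained('gpt2')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded saving now also writes config.json (and generation_config.json when applicable).
booster.save_model(model, './gpt2_ckpt', shard=True)
dist.barrier()

# The checkpoint folder is therefore consumable by stock HuggingFace loading.
new_model = GPT2LMHeadModel.from_pretrained('./gpt2_ckpt')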

View File

@@ -208,7 +208,7 @@ jobs:
- name: Execute Unit Testing
run: |
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1

View File

@@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer):
def init_pipeline_optimizer(optim: Optimizer, model: Module):
params = set(model.parameters())
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
params = [p for p in group['params'] if p in params]
params = [p for p in group['params'] if p in model_params]
new_param_groups.append({**group, 'params': params})
optim.__setstate__({'param_groups': new_param_groups})
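
For context on the fix just above: the old code rebound the name `params` inside the loop, so with more than one param group every group after the first was filtered against the previous group's result rather than against the model's parameter set. A minimal standalone repro (toy parameters, not code from this patch):

# Minimal repro of the shadowing bug fixed above.
import torch

w1 = torch.nn.Parameter(torch.zeros(1))
w2 = torch.nn.Parameter(torch.ones(1))
optim = torch.optim.SGD([{'params': [w1]}, {'params': [w2]}], lr=0.1)

params = {w1, w2}                                          # old code: the model's parameter set
for group in optim.param_groups:
    params = [p for p in group['params'] if p in params]   # rebinds `params` on every iteration
print(params)   # [] -- w2 was filtered against [w1] instead of the full set and got dropped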

View File

@@ -26,6 +26,7 @@ from .utils import (
load_shard_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict_shards,
search_tp_partition_dim,
@@ -204,6 +205,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
@@ -219,9 +221,9 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
@@ -229,7 +231,8 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors)
use_safetensors=use_safetensors,
use_pp_format=True)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
@@ -251,6 +254,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
final_index_file.append_weight_map(weight, weight_filename)
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
@@ -423,15 +427,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving)
is_master=control_saving,
use_pp_format=True)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
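
A short sketch of the resulting per-stage naming (the base filenames below are the usual HuggingFace-style defaults and are assumed here, not shown in the diff): switching from pp_rank to pp_rank + 1 makes stage numbering start at 1, matching the 1-based shard numbering produced by get_shard_filename further down.

# Sketch of the per-stage filenames (base names assumed; 0-based pp_rank and shard idx).
weights_name = 'pytorch_model.bin'
save_index_file = 'pytorch_model.bin.index.json'
pp_rank, idx = 0, 0

stage_weights = weights_name.replace('.bin', f'-stage-{pp_rank+1:05d}-shard.bin')
stage_index = save_index_file.replace('.json', f'-stage-{pp_rank+1:05d}.json')
shard_file = stage_weights.replace('.bin', f'-{idx+1:05d}.bin')   # what get_shard_filename does

print(stage_weights)   # pytorch_model-stage-00001-shard.bin
print(stage_index)     # pytorch_model.bin.index-stage-00001.json
print(shard_file)      # pytorch_model-stage-00001-shard-00001.bin

The last name is exactly the form that the new pipeline-format cleanup pattern in clean_folder expects.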

View File

@@ -9,12 +9,12 @@ from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
@@ -228,7 +228,8 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False) -> int:
use_safetensors: bool = False,
use_pp_format: bool = False) -> int:
'''
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
@@ -236,14 +237,16 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is master.
use_safetensors (bool): Whether to use safetensors to save checkpoint.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
'''
total_size = 0
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
if not is_master:
@@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
shard_filenames.append(shard_file)
del shard
# Clean folder, delete unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size
@@ -335,6 +342,66 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
torch.save(param_groups, group_file_path)
def clean_folder(checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False):
"""
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
Args:
checkpoint_path (str): Path to the checkpoint directory.
weights_name (str): Decides the prefix of filenames of weight shards.
shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
is_master (bool, optional): Whether current rank is main process. Defaults to True.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
"""
if is_master:
for filename in os.listdir(checkpoint_path):
full_filename = os.path.join(checkpoint_path, filename)
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
if not use_pp_format:
reg = re.compile(r"(.*?)-\d{5}")
else:
# When this checkpoint is created by pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
os.remove(full_filename)
def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
"""
Save config.json/generation_config.json if model is a Huggingface pretrained model.
This method can only be called when a model is saved in a sharded way.
Args:
model (nn.Module): The model whose config should be saved if it's a huggingface model.
checkpoint_path (str): Path to the checkpoint directory.
is_master (bool): Whether current rank is main process.
"""
if not isinstance(model, PreTrainedModel):
return
model = unwrap_huggingface_model(model)
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
dtype = get_parameter_dtype(model)
model.config.torch_dtype = str(dtype).split(".")[1]
# Attach architecture to the config
model.config.architectures = [model.__class__.__name__]
# Save the config
if is_master:
model.config.save_pretrained(checkpoint_path)
if model.can_generate():
model.generation_config.save_pretrained(checkpoint_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
@@ -709,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int):
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file
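
To make the cleanup helper added above concrete, here is a small check of the two patterns used by clean_folder (the filenames are made up for illustration; clean_folder additionally requires a matching prefix and skips files listed in shard_filenames):

# Illustrative comparison of clean_folder's two patterns (suffixes already stripped).
import re

plain_fmt = re.compile(r"(.*?)-\d{5}")                    # default (non-pipeline) shards
pp_fmt = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")     # pipeline-stage shards

for name in ['pytorch_model-00001',
             'pytorch_model-stage-00001-shard-00002',
             'pytorch_model']:
    print(name, bool(plain_fmt.fullmatch(name)), bool(pp_fmt.fullmatch(name)))
# pytorch_model-00001                     True   False
# pytorch_model-stage-00001-shard-00002   True   True
# pytorch_model                           False  False

Because the stage-aware pattern is the stricter one, a pipeline-parallel save only removes stale stage-format shards and leaves plain-format files untouched.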

View File

@@ -0,0 +1,129 @@
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
def exam_from_pretrained(model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
test_config,
shard=True,
size_per_shard=32):
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
def _preprocess_data(data):
if booster.plugin.stage_manager is not None:
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)
return iter([data])
else:
return {k: v.cuda() for k, v in data.items()}
model = model_fn()
optimizer = Adam((model.parameters()), lr=0.001)
criterion = loss_fn
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(_preprocess_data(data),
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=False)
else:
output = model(**_preprocess_data(data))
loss = criterion(output)
optimizer.backward(loss)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()
new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
Randomizer.reset_index()
torch.cuda.empty_cache()
@clear_cache_before_run()
@parameterize('test_config', [{
'tp_size': 4,
'pp_size': 1,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 2,
'pp_size': 1,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_test()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_huggingface_compatibility(world_size):
spawn(run_dist, world_size)
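
A hedged note on running the new test: the file path is not shown in this view, so the snippet below selects it by marker and function name instead; four CUDA devices and NCCL are required for world_size=4.

# Select the new distributed test by marker and name (file path intentionally omitted).
import pytest

pytest.main(['-m', 'dist', '-k', 'test_huggingface_compatibility', 'tests/'])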