mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support from_pretrained when loading model with HybridParallelPlugin (#4575)
* hybrid plugin support huggingface from_pretrained
* add huggingface compatibility tests
* add folder cleaning
* fix bugs

pull/4619/head
parent c9625dbb63
commit 38ccb8b1a3
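For context, the round trip this commit enables looks roughly like the sketch below, distilled from the compatibility test added at the end of this diff. The GPT-2 model, the checkpoint path, and the torchrun-style launch across 4 GPUs are illustrative assumptions, not part of the change itself.

    import torch.distributed as dist
    from torch.optim import Adam
    from transformers import GPT2Config, GPT2LMHeadModel

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch(config={})  # assumes a 4-process torchrun launch (illustrative)

    model = GPT2LMHeadModel(GPT2Config())
    optimizer = Adam(model.parameters(), lr=1e-3)
    plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4,
                                  precision='fp16', initial_scale=1)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # ... run some training steps ...

    # Save a sharded checkpoint; as of this commit config.json is written alongside the shards,
    # so the directory is a valid HuggingFace checkpoint.
    booster.save_model(model, "ckpt/model", shard=True, size_per_shard=32)
    dist.barrier()

    # The checkpoint can now be reloaded with plain transformers from_pretrained.
    new_model = GPT2LMHeadModel.from_pretrained("ckpt/model")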
@@ -208,7 +208,7 @@ jobs:

       - name: Execute Unit Testing
         run: |
-          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-cov=. --durations=10 tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
@@ -141,10 +141,10 @@ def get_param_info(optim: Optimizer):


 def init_pipeline_optimizer(optim: Optimizer, model: Module):
-    params = set(model.parameters())
+    model_params = set(model.parameters())
     new_param_groups = []
     for group in optim.param_groups:
-        params = [p for p in group['params'] if p in params]
+        params = [p for p in group['params'] if p in model_params]
         new_param_groups.append({**group, 'params': params})
     optim.__setstate__({'param_groups': new_param_groups})

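The optimizer hunk above fixes a variable-shadowing bug: the set built from model.parameters() and the filtered list inside the loop both used the name params, so from the second parameter group onward the filter ran against the previous group's list instead of the full parameter set. A stripped-down illustration with stand-in integer "parameters" (values are hypothetical, not repo code):

    # Old behaviour: `params` is rebound to a list inside the loop.
    params = {1, 2, 3, 4}                                   # stands in for set(model.parameters())
    param_groups = [{'params': [1, 2]}, {'params': [3, 4]}]

    filtered = []
    for group in param_groups:
        params = [p for p in group['params'] if p in params]
        filtered.append(params)
    print(filtered)    # [[1, 2], []]  -- the second group silently loses its parameters

    # Fixed behaviour: keep the set under a distinct name, as in the patch.
    model_params = {1, 2, 3, 4}
    filtered = []
    for group in param_groups:
        params = [p for p in group['params'] if p in model_params]
        filtered.append(params)
    print(filtered)    # [[1, 2], [3, 4]]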
@@ -26,6 +26,7 @@ from .utils import (
     load_shard_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict_shards,
     search_tp_partition_dim,
@@ -204,6 +205,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         if control_saving:
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
+            save_config_file(model, checkpoint)
             if self.verbose:
                 logging.info(f"The model is split into checkpoint shards. "
                              f"You can find where each parameters has been saved in the "
@@ -219,9 +221,9 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

             # Manage filenames of sharded weights and index file for each pipeline stage.
-            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)

             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
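The hunk above moves the stage index in per-stage filenames from the 0-based pp_rank to pp_rank + 1, so stage numbering starts at 00001, in line with the 1-based shard numbering produced by get_shard_filename further down. A quick illustration of the resulting names, assuming the default "pytorch_model.bin" and "pytorch_model.bin.index.json" prefixes:

    pp_rank = 0    # first pipeline stage (illustrative)
    weights_name = "pytorch_model.bin"
    save_index_file = "pytorch_model.bin.index.json"

    weights_name = weights_name.replace(".bin", f"-stage-{pp_rank+1:05d}-shard.bin")
    save_index_file = save_index_file.replace(".json", f"-stage-{pp_rank+1:05d}.json")

    print(weights_name)       # pytorch_model-stage-00001-shard.bin
    print(save_index_file)    # pytorch_model.bin.index-stage-00001.json

Once get_shard_filename appends the per-shard counter, the weight files come out as e.g. pytorch_model-stage-00001-shard-00001.bin, which is exactly the shape the pipeline-format pattern in clean_folder (added below) looks for.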
@@ -229,7 +231,8 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                                                 index_file=index_file,
                                                 base_filename=weights_name,
                                                 is_master=control_saving,
-                                                use_safetensors=use_safetensors)
+                                                use_safetensors=use_safetensors,
+                                                use_pp_format=True)
             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
                 index_file.append_meta_data("total_size", total_size)
@@ -251,6 +254,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                 final_index_file.append_weight_map(weight, weight_filename)

             final_index_file.write_index_file(final_index_file_path)
+            save_config_file(model, checkpoint)
             rmtree(tmp_index_file_folder)
             if self.verbose:
                 logging.info(f"The model is split into checkpoint shards. "
@@ -423,15 +427,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

             # Manage filenames of sharded weights and index file for each pipeline stage.
-            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin")
-            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json")
+            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)

             total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
                                                 checkpoint=checkpoint,
                                                 index_file=index_file,
                                                 base_filename=states_name,
-                                                is_master=control_saving)
+                                                is_master=control_saving,
+                                                use_pp_format=True)

             if control_saving:
                 assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
@@ -9,12 +9,12 @@ from pathlib import Path
 from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
+from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model

-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
@@ -228,7 +228,8 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
                            index_file: "CheckpointIndexFile",
                            base_filename: str,
                            is_master: bool,
-                           use_safetensors: bool = False) -> int:
+                           use_safetensors: bool = False,
+                           use_pp_format: bool = False) -> int:
     '''
     Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
     Args:
@@ -236,14 +237,16 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
         checkpoint (str): The path of checkpoint directory as string.
         index_file (CheckpointIndexFile): The index file object to be updated.
         base_filename (str): Decides the prefix of filenames of shards.
-        is_master (bool): Whether current rank is master.
-        use_safetensors (bool): Whether to use safetensors to save checkpoint.
+        is_master (bool): Whether current rank is main process.
+        use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.

     Returns:
         int: the total size of shards
     '''

     total_size = 0
+    shard_filenames = []
     for idx, shard_pair in enumerate(sharded_state_dict):
         shard, current_size = shard_pair
         if not is_master:
@@ -257,8 +260,12 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]

         # Only save on master rank.
         save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+        shard_filenames.append(shard_file)
         del shard
+
+    # Clean folder, deleted unneeded files.
+    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)

     return total_size


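With the changes above, save_state_dict_shards now records every filename it actually writes and passes the list to clean_folder, so shards left over from an earlier save with the same prefix are deleted. A simplified, self-contained sketch of that save-then-clean pattern (a stand-in written for illustration, not the library function; torch.save and plain dicts replace the real shard iterator):

    import os
    import torch

    def save_shards_then_clean(shards, checkpoint_dir, prefix="pytorch_model.bin", is_master=True):
        """Stand-in for the save-then-clean pattern: write shards, then drop stale ones."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        total_size, shard_filenames = 0, []
        for idx, (shard, size) in enumerate(shards):
            if not is_master:
                continue
            shard_file = prefix.replace(".bin", f"-{idx+1:05d}.bin")
            torch.save(shard, os.path.join(checkpoint_dir, shard_file))
            shard_filenames.append(shard_file)
            total_size += size
        if is_master:
            prefix_no_suffix = prefix.replace(".bin", "")
            for name in os.listdir(checkpoint_dir):
                # Anything that looks like a shard of this prefix but was not just written is stale.
                if (name.startswith(prefix_no_suffix + "-") and name.endswith(".bin")
                        and name not in shard_filenames):
                    os.remove(os.path.join(checkpoint_dir, name))
        return total_size

    # Example: two tiny "shards" given as (state_dict, size-in-bytes) pairs.
    shards = [({"weight": torch.zeros(2)}, 8), ({"bias": torch.zeros(2)}, 8)]
    print(save_shards_then_clean(shards, "tmp_ckpt"))    # 16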
@@ -335,6 +342,66 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None:
     torch.save(param_groups, group_file_path)


+def clean_folder(checkpoint_path: str,
+                 weights_name: str,
+                 shard_filenames: List[str],
+                 is_master: bool = True,
+                 use_pp_format: bool = False):
+    """
+    Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
+
+    Args:
+        checkpoint_path (str): Path to the checkpoint directory.
+        weights_name (str): Decides the prefix of filenames of weight shards.
+        shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+        is_master (bool, optional): Whether current rank is main process. Defaults to True.
+        use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+    """
+    if is_master:
+        for filename in os.listdir(checkpoint_path):
+            full_filename = os.path.join(checkpoint_path, filename)
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            if not use_pp_format:
+                reg = re.compile(r"(.*?)-\d{5}")
+            else:
+                # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
+                reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+            if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
+                    and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+                os.remove(full_filename)
+
+
+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+    """
+    Save config.json/generation_config.json if model is a Huggingface pretrained model.
+    This method can only be called when a model is saved in a sharded way.
+
+    Args:
+        model (nn.Module): The model whose config should be saved if it's a huggingface model.
+        checkpoint_path (str): Path to the checkpoint directory.
+        is_master (bool): Whether current rank is main process.
+    """
+    if not isinstance(model, PreTrainedModel):
+        return
+
+    model = unwrap_huggingface_model(model)
+
+    # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+    dtype = get_parameter_dtype(model)
+    model.config.torch_dtype = str(dtype).split(".")[1]
+
+    # Attach architecture to the config
+    model.config.architectures = [model.__class__.__name__]
+
+    # Save the config
+    if is_master:
+        model.config.save_pretrained(checkpoint_path)
+        if model.can_generate():
+            model.generation_config.save_pretrained(checkpoint_path)
+
+
 def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
     """
     Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
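The two regular expressions in clean_folder above decide which files count as weight shards of the current prefix (after the ".bin"/".safetensors" suffix has been stripped). A small check of what each pattern accepts, using hypothetical filenames:

    import re

    # Patterns as in clean_folder; the filenames below are made-up examples.
    reg_plain = re.compile(r"(.*?)-\d{5}")
    reg_pp = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")

    names = [
        "pytorch_model-00001",                      # shard from a plain sharded save
        "pytorch_model-stage-00001-shard-00002",    # shard from a pipeline-parallel save
        "pytorch_model",                            # unsharded file, never treated as a shard
    ]
    for name in names:
        print(name, bool(reg_plain.fullmatch(name)), bool(reg_pp.fullmatch(name)))
    # pytorch_model-00001                     True   False
    # pytorch_model-stage-00001-shard-00002   True   True
    # pytorch_model                           False  False

Note that the lazy plain pattern also happens to fullmatch the stage-style names (its group simply absorbs the "-stage-00001-shard" part), while the pipeline pattern only accepts files that carry both a stage and a shard counter.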
@@ -709,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int):
     get shard file name
     """
     shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
-    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
     return shard_file
@@ -0,0 +1,129 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.optim import Adam
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import (
+    check_state_dict_equal,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+
+
+def exam_from_pretrained(model_fn,
+                         data_gen_fn,
+                         output_transform_fn,
+                         loss_fn,
+                         test_config,
+                         shard=True,
+                         size_per_shard=32):
+
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
+    def _preprocess_data(data):
+        if booster.plugin.stage_manager is not None:
+            for k, v in data.items():
+                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+                    new_shape = [1] * v.dim()
+                    new_shape[0] = 4
+                    data[k] = v.to('cuda').repeat(*new_shape)
+            return iter([data])
+        else:
+            return {k: v.cuda() for k, v in data.items()}
+
+    model = model_fn()
+    optimizer = Adam((model.parameters()), lr=0.001)
+    criterion = loss_fn
+    plugin = HybridParallelPlugin(**test_config)
+    booster = Booster(plugin=plugin)
+
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    model.train()
+    if booster.plugin.stage_manager is not None:
+        booster.execute_pipeline(_preprocess_data(data),
+                                 model,
+                                 _criterion,
+                                 optimizer,
+                                 return_loss=True,
+                                 return_outputs=False)
+    else:
+        output = model(**_preprocess_data(data))
+        loss = criterion(output)
+        optimizer.backward(loss)
+
+    optimizer.step()
+
+    with shared_tempdir() as tempdir:
+
+        model_ckpt_path = f"{tempdir}/model"
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        dist.barrier()
+
+        new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
+        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
+
+        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+@clear_cache_before_run()
+@parameterize('test_config', [{
+    'tp_size': 4,
+    'pp_size': 1,
+    'precision': 'fp32',
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 2,
+    'pp_size': 1,
+    'zero_stage': 2,
+    'precision': 'fp16',
+    'initial_scale': 1
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'zero_stage': 1,
+    'precision': 'fp16',
+    'initial_scale': 1
+}])
+def run_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_test()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4])
+@rerun_if_address_is_in_use()
+def test_huggingface_compatibility(world_size):
+    spawn(run_dist, world_size)