diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 0f5ba6e9a..8489a8f29 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -18,6 +18,7 @@ from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
     load_shard_state_dict,
+    save_config_file,
     save_state_dict,
     save_state_dict_shards,
 )
@@ -111,6 +112,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
+            save_config_file(model.module, checkpoint_path)
             logging.info(f"The model is split into checkpoint shards. "
                          f"You can find where each parameters has been saved in the "
                          f"index located at {save_index_file}.")
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 83e4bdcc8..09362d145 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -23,6 +23,7 @@ from .utils import (
     load_state_dict,
     load_state_dict_into_model,
     load_states_into_optimizer,
+    save_config_file,
     save_param_groups,
     save_state_dict,
     save_state_dict_shards,
@@ -185,6 +186,7 @@ class GeneralCheckpointIO(CheckpointIO):
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
+        save_config_file(model, checkpoint_path, is_master=True)
         logging.info(f"The model is going to be split to checkpoint shards. "
                      f"You can find where each parameters has been saved in the "
                      f"index located at {save_index_file}.")
diff --git a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py b/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py
deleted file mode 100644
index df907605d..000000000
--- a/tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-from torch.optim import Adam
-from utils import shared_tempdir
-
-import colossalai
-from colossalai.booster import Booster
-from colossalai.booster.plugin import HybridParallelPlugin
-from colossalai.shardformer.layer.utils import Randomizer
-from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import (
-    check_state_dict_equal,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
-from tests.kit.model_zoo import model_zoo
-
-
-def exam_from_pretrained(model_fn,
-                         data_gen_fn,
-                         output_transform_fn,
-                         loss_fn,
-                         test_config,
-                         shard=True,
-                         size_per_shard=32):
-
-    def _criterion(outputs, inputs):
-        outputs = output_transform_fn(outputs)
-        loss = criterion(outputs)
-        return loss
-
-    def _preprocess_data(data):
-        if booster.plugin.stage_manager is not None:
-            for k, v in data.items():
-                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
-                    new_shape = [1] * v.dim()
-                    new_shape[0] = 4
-                    data[k] = v.to('cuda').repeat(*new_shape)
-            return iter([data])
-        else:
-            return {k: v.cuda() for k, v in data.items()}
-
-    model = model_fn()
-    optimizer = Adam((model.parameters()), lr=0.001)
-    criterion = loss_fn
-    plugin = HybridParallelPlugin(**test_config)
-    booster = Booster(plugin=plugin)
-
-    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-
-    data = data_gen_fn()
-    model.train()
-    if booster.plugin.stage_manager is not None:
-        booster.execute_pipeline(_preprocess_data(data),
-                                 model,
-                                 _criterion,
-                                 optimizer,
-                                 return_loss=True,
-                                 return_outputs=False)
-    else:
-        output = model(**_preprocess_data(data))
-        loss = criterion(output)
-        optimizer.backward(loss)
-
-    optimizer.step()
-
-    with shared_tempdir() as tempdir:
-
-        model_ckpt_path = f"{tempdir}/model"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
-        dist.barrier()
-
-        new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
-        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
-        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
-
-        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
-
-    Randomizer.reset_index()
-    torch.cuda.empty_cache()
-
-
-@clear_cache_before_run()
-@parameterize('test_config', [{
-    'tp_size': 4,
-    'pp_size': 1,
-    'precision': 'fp32',
-}, {
-    'tp_size': 2,
-    'pp_size': 2,
-    'num_microbatches': 4,
-    'precision': 'fp16',
-    'initial_scale': 1
-}, {
-    'tp_size': 2,
-    'pp_size': 1,
-    'zero_stage': 2,
-    'precision': 'fp16',
-    'initial_scale': 1
-}, {
-    'tp_size': 1,
-    'pp_size': 2,
-    'num_microbatches': 4,
-    'zero_stage': 1,
-    'precision': 'fp16',
-    'initial_scale': 1
-}])
-def run_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        exam_from_pretrained(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-    clear_layout_converter()
-    torch.cuda.empty_cache()
-
-
-def run_dist(rank, world_size, port):
-    config = {}
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_test()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [4])
-@rerun_if_address_is_in_use()
-def test_huggingface_compatibility(world_size):
-    spawn(run_dist, world_size)
diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
new file mode 100644
index 000000000..3f3b0392a
--- /dev/null
+++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
@@ -0,0 +1,83 @@
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+from utils import shared_tempdir
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import (
+    check_state_dict_equal,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+
+
+@clear_cache_before_run()
+@parameterize('model_name', ['transformers_gpt'])
+@parameterize('plugin_type', ['ddp', 'zero', 'gemini'])
+def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
+    (model_fn, data_gen_fn, output_transform_fn, loss_fn,
+     _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    criterion = loss_fn
+
+    if plugin_type == 'ddp':
+        plugin = TorchDDPPlugin()
+    elif plugin_type == 'zero':
+        plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
+    elif plugin_type == 'gemini':
+        plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
+    else:
+        raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
+
+    booster = Booster(plugin=plugin)
+
+    model = model_fn().cuda()
+    model_huggingface_cls = model.__class__
+    optimizer = HybridAdam(model.parameters(), lr=0.001)
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
+    output = model(**data)
+    loss = criterion(output)
+
+    booster.backward(loss, optimizer)
+    optimizer.step()
+
+    with shared_tempdir() as tempdir:
+
+        model_ckpt_path = f"{tempdir}/model"
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        dist.barrier()
+
+        new_model = model_huggingface_cls.from_pretrained(model_ckpt_path)
+        new_model = new_model.cuda()
+        new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
+        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
+
+        if plugin_type == 'gemini':
+            check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
+                                   new_model.unwrap().state_dict(only_rank_0=False), False)
+        else:
+            check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+        dist.barrier()
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_from_pretrained()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_huggingface_compatibility(world_size):
+    spawn(run_dist, world_size)
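
Note: with save_config_file wired into both sharded-save paths above, a directory produced by booster.save_model(..., shard=True) now contains the model's HuggingFace config.json next to the weight shards and index file, so it can be reloaded directly with from_pretrained, which is what the new test exercises. A minimal end-to-end sketch of that workflow under assumed names (GPT-2 from transformers, a single-process TorchDDP setup, and an illustrative checkpoint path; none of these come from this diff):

import torch
from transformers import GPT2LMHeadModel

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumed single-process setup; in practice this is started by a launcher such as torchrun.
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=29500, backend='nccl')

booster = Booster(plugin=TorchDDPPlugin())
model = GPT2LMHeadModel.from_pretrained('gpt2').cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# Sharded save: writes the weight shards, the index file, and (after this change) config.json.
booster.save_model(model, './gpt2_hf_ckpt', shard=True, size_per_shard=32)

# The checkpoint directory now looks like a plain HuggingFace checkpoint and loads
# without ColossalAI in the loop.
reloaded = GPT2LMHeadModel.from_pretrained('./gpt2_hf_ckpt')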