def _assert_state_values_equal(v1, v2, ignore_device: bool, where: str) -> None:
    """Assert that one pair of leaf values is equal.

    Tensors are compared with ``torch.equal``; every other value with ``==``.

    Args:
        v1: Value taken from the reference state dict.
        v2: Value taken from the state dict under test.
        ignore_device: When ``False``, tensors are copied to the CPU first.
        where: Human-readable location used in the assertion message.
    """
    if isinstance(v1, torch.Tensor):
        if not ignore_device:
            # The CPU copies live only in these locals, so the caller's
            # dicts and lists are never rewritten by the comparison.
            v1 = v1.to("cpu")
            v2 = v2.to("cpu")
        assert torch.equal(v1, v2), f'tensors at {where} are not equal'
    else:
        assert v1 == v2, f'values at {where} are not equal: {v1} vs {v2}'


def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
    """Assert that two (possibly nested) state dicts hold equal values.

    Every key of ``d1`` is looked up in ``d2`` and the two values are compared:
    nested dicts recursively, lists element by element, tensors with
    ``torch.equal`` and everything else with ``==``.

    Args:
        d1: Reference state dict; its keys drive the comparison.
        d2: State dict to compare against. Keys that exist only in ``d2``
            are not inspected.
        ignore_device: When ``False``, tensors are copied to the CPU before
            being compared; when ``True`` (the default) they are compared
            as they are. NOTE: despite its name the flag only controls the
            CPU copy. The polarity is kept because existing callers pass
            ``False`` to request the copy.

    Raises:
        AssertionError: If a pair of values differs, or two lists differ in
            length.
        KeyError: If a key of ``d1`` is missing from ``d2``.
    """
    for k, v1 in d1.items():
        v2 = d2[k]
        if isinstance(v1, dict):
            # Forward the flag so nested levels (e.g. per-parameter optimizer
            # state) are compared under the same rule as the top level.
            check_state_dict_equal(v1, v2, ignore_device)
        elif isinstance(v1, list):
            # zip() would silently stop at the shorter list, so check lengths first.
            assert len(v1) == len(v2), f'lists at key {k!r} differ in length: {len(v1)} vs {len(v2)}'
            for i, (item1, item2) in enumerate(zip(v1, v2)):
                _assert_state_values_equal(item1, item2, ignore_device, f'key {k!r}, index {i}')
        else:
            _assert_state_values_equal(v1, v2, ignore_device, f'key {k!r}')
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
    """Round-trip a ``ZeroDDP`` model through ``GeminiCheckpointIO`` and compare state dicts.

    Two models are built by the same builder and each is wrapped in ``ZeroDDP``
    with its own chunk manager. The first is written with ``save_model``; on
    the master rank the checkpoint is loaded into the second model and both
    ``state_dict(only_rank_0=True)`` results are compared with
    ``check_state_dict_equal``.

    Args:
        placement_policy: Placement policy handed to ``GeminiManager``.
        model_name: Key looked up in ``non_distributed_component_funcs``.
        use_safetensors: Forwarded to ``save_model``.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, *_ = get_components_func()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
        new_model = model_builder()

    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    model.train()

    # The second model gets its own, separately searched chunk configuration.
    new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
    new_chunk_manager = ChunkManager(new_config_dict)
    new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
    new_model = ZeroDDP(new_model, new_gemini_manager)

    # The context manager removes the directory even when the load or the
    # comparison below raises; a trailing cleanup() call would be skipped then.
    with tempfile.TemporaryDirectory() as model_ckpt_dir:
        ckpt_io = GeminiCheckpointIO()
        # Parameter bytes converted to MiB.
        model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
        # A third of the model size is passed as the size argument - presumably the
        # per-shard limit, so that several shards are written; TODO confirm against save_model.
        ckpt_io.save_model(model,
                           model_ckpt_dir,
                           True,
                           True,
                           'epoch',
                           model_size / 3,
                           use_safetensors=use_safetensors)

        if ckpt_io.coordinator.is_master():
            ckpt_io.load_model(new_model, model_ckpt_dir, strict=True)
            model_dict = model.state_dict(only_rank_0=True)
            new_model_dict = new_model.state_dict(only_rank_0=True)
            # Passing False makes the helper copy tensors to the CPU before comparing.
            check_state_dict_equal(model_dict, new_model_dict, False)


def run_dist(rank, world_size, port):
    """Launch ColossalAI with the NCCL backend for this rank, then run both Gemini checks.

    Handed to ``spawn`` by ``test_gemini_ckpIO`` - presumably executed once per rank.
    """
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    exam_state_dict_with_origin()


# NOTE(review): 4 is listed twice, so this test is parametrized twice with the
# same world size - confirm whether a different second value was intended.
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4, 4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
    """Run the Gemini checkpoint checks through ``spawn`` with ``world_size`` processes."""
    spawn(run_dist, world_size)
recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + @pytest.mark.parametrize('use_safetensors', [True, False]) def test_sharded_checkpoint(use_safetensors: bool): @@ -87,7 +80,7 @@ def test_sharded_checkpoint(use_safetensors: bool): else: suffix = ".bin" WEIGHTS_INDEX_NAME = "model.bin.index.json" - + model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -96,7 +89,7 @@ def test_sharded_checkpoint(use_safetensors: bool): ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) - + # create new model new_model = resnet18() new_optimizer = Adam(new_model.parameters(), lr=0.001) @@ -105,111 +98,5 @@ def test_sharded_checkpoint(use_safetensors: bool): ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) # check for model and optimizer state dict recursively - recursive_check(model.state_dict(), new_model.state_dict()) - recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) - -@parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('model_name', ['bert']) -@parameterize('use_safetensors', [True, False]) -def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool): - from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification - - model_ckpt_dir = tempfile.TemporaryDirectory() - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, *_ = get_components_func() - - with ColoInitContext(device=get_current_device()): - bert_model = model_builder() - bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name) - config_dict, *_ = search_chunk_configuration(bert_model, 
search_range_mb=1, search_interval_byte=100) - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - bert_model = ZeroDDP(bert_model, gemini_manager) - bert_model.train() - - ckpt_io = GeminiCheckpointIO() - if ckpt_io.coordinator.is_master(): - model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 - ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors) - new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name) - recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict()) - - model_ckpt_dir.cleanup() - - - -@parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('model_name', ['gpt2', 'bert']) -@parameterize('use_safetensors', [True, False]) -def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, *_ = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - new_model = model_builder() - - config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager) - model.train() - - new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100) - new_chunk_manager = ChunkManager(new_config_dict) - new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) - new_model = ZeroDDP(new_model, new_gemini_manager) - - model_ckpt_dir = tempfile.TemporaryDirectory() - - ckpt_io = GeminiCheckpointIO() - model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 - ckpt_io.save_model(model, model_ckpt_dir.name, True, True, 
"epoch", (model_size / 3), use_safetensors=use_safetensors) - - # load model - if ckpt_io.coordinator.is_master(): - ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True) - model_dict = model.state_dict(only_rank_0=True) - new_model_dict = new_model.state_dict(only_rank_0=True) - recursive_check(model_dict, new_model_dict) - - model_ckpt_dir.cleanup() - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_state_dict() - hf_load_colossalai_checkpoint() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4, 4]) -@rerun_if_address_is_in_use() -def test_gemini_ckpIO(world_size): - spawn(run_dist, world_size) - - -# do recursive check for the optimizer state dict -# if the value is a dict, compare its values -# if the value is a list, comapre all elements one-by-one -# if the value is a torch.Tensor, use torch.equal -# otherwise use assertEqual -def recursive_check(d1, d2): - for k, v in d1.items(): - if isinstance(v, dict): - recursive_check(v, d2[k]) - elif isinstance(v, list): - for i in range(len(v)): - if isinstance(v[i], torch.Tensor): - v[i] = v[i].to("cpu") - d2[k][i] = d2[k][i].to("cpu") - assert torch.equal(v[i], d2[k][i]) - else: - assert v[i] == d2[k][i] - elif isinstance(v, torch.Tensor): - v = v.to("cpu") - d2[k] = d2[k].to("cpu") - assert torch.equal(v, d2[k]) - else: - assert v == d2[k] + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py new file mode 100644 index 000000000..217a950d8 --- /dev/null +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -0,0 +1,57 @@ +import tempfile + +import pytest +import torch +from torchvision.models import resnet18 + +import 
@clear_cache_before_run()
@parameterize('stage', [2])
def check_low_level_zero_checkpointIO(stage: int):
    """Save a boosted ``HybridAdam`` optimizer with ``LowLevelZeroCheckpointIO`` and reload it.

    A resnet18 and its optimizer are boosted with ``LowLevelZeroPlugin`` and one
    forward/backward/step is run. The optimizer is then saved; on the master
    rank it is loaded into a freshly boosted optimizer and both
    ``state_dict()`` results are compared with ``check_state_dict_equal``.

    Args:
        stage: ZeRO stage passed to ``LowLevelZeroPlugin``.
    """
    plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # One training step before saving - presumably so the optimizer state is not empty.
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.step()

    # Close the temporary file deterministically, also when an assertion below
    # fails, instead of leaving it to garbage collection.
    with tempfile.NamedTemporaryFile() as optimizer_ckpt_tempfile:
        ckpt_io = LowLevelZeroCheckpointIO()
        ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)

        if ckpt_io.coordinator.is_master():
            new_model = resnet18()
            new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
            _, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
            ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
            # Passing False makes the helper copy tensors to the CPU before comparing.
            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)


def run_dist(rank, world_size, port):
    """Launch ColossalAI for this rank, then run the low-level-zero checkpoint check.

    Handed to ``spawn`` by ``test_low_level_zero_checkpointIO`` - presumably executed once per rank.
    """
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_low_level_zero_checkpointIO()


@rerun_if_address_is_in_use()
def test_low_level_zero_checkpointIO():
    """Run the low-level-zero checkpoint check through ``spawn`` with 2 processes."""
    spawn(run_dist, 2)
def check_torch_ddp_checkpointIO():
    """Save optimizer and LR-scheduler state with ``TorchDDPCheckpointIO`` and reload both.

    A resnet18, an SGD optimizer and a ``StepLR`` scheduler are boosted with
    ``TorchDDPPlugin`` and one training step is run. Optimizer and scheduler
    are saved to temporary files; on the master rank they are loaded into
    freshly boosted objects and the ``state_dict()`` results are compared with
    ``check_state_dict_equal``.
    """
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # The plugin is expected to hand back wrapped objects.
    assert isinstance(model.module, DDP)
    assert isinstance(optimizer, OptimizerWrapper)

    # One training step before saving - presumably so the saved state is not empty.
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()

    # Close both temporary files deterministically, also when an assertion below
    # fails, instead of leaving them to garbage collection.
    with tempfile.NamedTemporaryFile() as optimizer_ckpt_tempfile, \
            tempfile.NamedTemporaryFile() as lr_scheduler_ckpt_tempfile:
        ckpt_io = TorchDDPCheckpointIO()
        ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
        ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name)

        if ckpt_io.coordinator.is_master():
            new_model = resnet18()
            new_optimizer = SGD(new_model.parameters(), lr=0.001)
            new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
            _, new_optimizer, _, _, new_scheduler = booster.boost(new_model,
                                                                  new_optimizer,
                                                                  lr_scheduler=new_scheduler)

            # Passing False makes the helper copy tensors to the CPU before comparing.
            ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)

            ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name)
            check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)


def run_dist(rank, world_size, port):
    """Launch ColossalAI for this rank, then run the torch-DDP checkpoint check.

    Handed to ``spawn`` by ``test_torch_ddp_checkpointIO`` - presumably executed once per rank.
    """
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_ddp_checkpointIO()


@rerun_if_address_is_in_use()
def test_torch_ddp_checkpointIO():
    """Run the torch-DDP checkpoint check through ``spawn`` with 2 processes."""
    spawn(run_dist, 2)