From 307894f74dd63d71f4b95272fe149ca607e2aafa Mon Sep 17 00:00:00 2001 From: jiangmingyan <1829166702@qq.com> Date: Fri, 5 May 2023 14:37:21 +0800 Subject: [PATCH] [booster] gemini plugin support shard checkpoint (#3610) * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint --------- Co-authored-by: luchen Co-authored-by: luchen --- colossalai/booster/plugin/gemini_plugin.py | 44 +++++++++ .../checkpoint_io/checkpoint_io_base.py | 2 +- .../checkpoint_io/general_checkpoint_io.py | 63 ++++++------ colossalai/checkpoint_io/index_file.py | 16 ++- colossalai/checkpoint_io/utils.py | 86 ++++++++-------- colossalai/zero/gemini/gemini_ddp.py | 51 +++++++--- pytest.ini | 2 +- .../test_general_checkpoint_io.py | 99 ++++++++++++++++++- .../test_zeroddp_state_dict_shard.py | 3 +- 9 files changed, 269 insertions(+), 97 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index deda00d8a..dfdd7be26 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,6 +1,9 @@ import random import warnings from typing import Callable, List, Optional, Tuple, Union +from pathlib import Path +import os +import logging import numpy as np import torch @@ -20,6 +23,13 @@ from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.zero.gemini.memory_tracer import MemStats +from colossalai.checkpoint_io.utils import ( + get_base_filenames, + get_shard_filename + ) + +from colossalai.checkpoint_io import CheckpointIndexFile + from .plugin_base import Plugin __all__ = ['GeminiPlugin'] @@ -62,6 +72,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): + """ + Save sharded model + """ + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) + weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + total_size = 0 + index_file = CheckpointIndexFile(checkpoint_path) + for idx, shard_pair in enumerate(state_dict_shard): + if not self.coordinator.is_master(): + continue + shard = shard_pair[0] + shard_file = get_shard_filename(weights_name, idx) + total_size = total_size + shard_pair[1] + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info( + f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + + def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + """ + load shard model, load model from multiple files + """ + return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) class GeminiModel(ModelWrapper): diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index cb853559c..9cf344ecc 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -86,7 +86,7 @@ class CheckpointIO(ABC): # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) - + # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index bf584f45d..96a883fdb 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -1,12 +1,12 @@ from pathlib import Path +from functools import reduce import torch.nn as nn from torch.optim import Optimizer import logging import os -import json import gc -from typing import Optional +from typing import Optional, Iterator, OrderedDict, Tuple from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile @@ -18,10 +18,9 @@ from .utils import ( shard_checkpoint, load_shard_state_dict, load_state_dict_into_model, - add_variant + get_shard_filename, + get_base_filenames ) -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - __all__ = ['GeneralCheckpointIO'] @@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO): # shard checkpoint state_dict = model.state_dict() - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) - shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) - - # Save the model - for shard_file, shard in shards.items(): + state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + + weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + total_size = 0 + index_file = CheckpointIndexFile(checkpoint_path) + for idx, shard_pair in enumerate(state_dict_shard): + shard = shard_pair[0] + shard_file = get_shard_filename(weights_name, idx) + total_size = total_size + shard_pair[1] + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) save_state_dict(shard, checkpoint_file_path, use_safetensors) - - # save index file - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - - save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) logging.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." ) - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, + use_safetensors: bool = False, load_sub_module: bool = True): """ load shard model, load model from multiple files """ @@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO): # read checkpoint index file ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() - missing_keys = ckpt_index_file.get_all_param_names() + missing_keys = [] for shard_file in checkpoint_files: state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) - load_state_dict_into_model(model, state_dict, missing_keys, strict) + load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module) del state_dict gc.collect() - if strict and len(missing_keys) > 0: - error_msgs = 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + if strict: + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) + if len(remain_keys) > 0: + error_msgs = 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + + diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 89224787a..15a6d09f3 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -1,6 +1,8 @@ import json from pathlib import Path from typing import Any, List, Union +import os +import json from .utils import is_dtensor_checkpoint @@ -18,8 +20,8 @@ class CheckpointIndexFile: >>> index.export('new_index.json') """ - def __init__(self) -> None: - self.root_path = None + def __init__(self, root_path=None) -> None: + self.root_path = root_path self.metadata: dict = dict() self.weight_map: dict = dict() @@ -154,3 +156,13 @@ class CheckpointIndexFile: Get all the weight keys. """ return list(self.weight_map.keys()) + + def write_index_file(self, save_index_file): + """ + Wriete index file. + """ + save_index_file = os.path.join(self.root_path, save_index_file) + index = {"metadata": self.metadata, "weight_map": self.weight_map} + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 37d22d08d..16e41631f 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -2,7 +2,7 @@ from pathlib import Path import torch import torch.nn as nn -from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple +from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator from colossalai.tensor.d_tensor.d_tensor import DTensor import re @@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - sharded_state_dicts = [] current_block = {} current_block_size = 0 - total_size = 0 for key, weight in state_dict.items(): + ret_block = None + ret_block_size = 0 if type(weight) != DTensor: weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. if current_block_size + weight_size > max_shard_size: - sharded_state_dicts.append(current_block) + ret_block = current_block + ret_block_size = current_block_size current_block = {} current_block_size = 0 - current_block[key] = weight current_block_size += weight_size - total_size += weight_size + + if ret_block != None: + yield ret_block, ret_block_size - # Add the last block - sharded_state_dicts.append(current_block) + yield current_block, current_block_size - # If we only have one shard, we return it - if len(sharded_state_dicts) == 1: - return {weights_name: sharded_state_dicts[0]}, None - - # Otherwise, let's build the index - weight_map = {} - shards = {} - - for idx, shard in enumerate(sharded_state_dicts): - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") - shard_file = shard_file.replace( - ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" - ) - shards[shard_file] = shard - for key in shard.keys(): - weight_map[key] = shard_file - - # Add the metadata - metadata = {"total_size": total_size} - index = {"metadata": metadata, "weight_map": weight_map} - return shards, index def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): """ @@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): else: return torch.load(checkpoint_file) -def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False): +def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. @@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi if metadata is not None: state_dict._metadata = metadata - def load(module: nn.Module, state_dict, prefix=""): + def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) # Parameters of module and children will start with prefix. We can exit early if there are none in this # state_dict if len([key for key in state_dict if key.startswith(prefix)]) > 0: module._load_from_state_dict(*args) + if load_sub_module: + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") - for name, child in module._modules.items(): - if child is not None: - load(child, state_dict, prefix + name + ".") - - load(model, state_dict, "") + load(model, state_dict, "", load_sub_module) del load - # deal with missing key - if len(missing_keys) > 0: - deleted_keys = [] - for key in missing_keys: - if key not in sub_missing_keys: - deleted_keys.append(key) - for key in deleted_keys: - missing_keys.remove(key) + missing_keys = missing_keys.append(sub_missing_keys) if strict: if len(unexpected_keys) > 0: @@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str: weights_name = ".".join(splits) return weights_name + + +def get_base_filenames(variant: str=None, use_safetensors: bool=False): + """ + generate base weight filenames + """ + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) + + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_variant(save_index_file, variant) + + return weights_name, save_index_file + +def get_shard_filename(weights_name: str, idx: int): + """ + get shard file name + """ + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") + return shard_file \ No newline at end of file diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 8a001b114..878c25be7 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -2,7 +2,7 @@ import itertools from collections import OrderedDict from contextlib import nullcontext from functools import partial -from typing import Dict, Iterator, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union, Tuple, Set import torch import torch.distributed as dist @@ -96,8 +96,35 @@ class ZeroDDP(ColoDDP): param_name = m_name + '.' + p_name if m_name else p_name self.name2param[param_name] = p_var super().__init__(module, process_group=ColoProcessGroup()) + self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module) self._cast_buffers() + def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True): + + r""" + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + """ + + if memo is None: + memo = set() + self_non_persistent_set = set() + if module not in memo: + if remove_duplicate: + memo.add(module) + self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) + for name, sub_module in module._modules.items(): + if sub_module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate) + self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) + return self_non_persistent_set + + def _post_forward(self): """This function is only triggered for inference. """ @@ -604,7 +631,7 @@ class ZeroDDP(ColoDDP): keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True, - dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: + dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. @@ -644,9 +671,9 @@ class ZeroDDP(ColoDDP): gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append(prefix + name, gathered_param) if block is not None: - yield block + yield block, block_size del fp16_to_fp32 del gathered_param_buffer @@ -655,19 +682,19 @@ class ZeroDDP(ColoDDP): for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block = sharder.append(prefix + name, buffer) + block, block_size = sharder.append(prefix + name, buffer) if block is not None: - yield block + yield block, block_size # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append(extra_state_key, extra_state) if block is not None: - yield block + yield block, block_size - yield sharder.current_block + yield sharder.current_block, sharder.current_block_size class _StateDictSharder: @@ -677,16 +704,18 @@ class _StateDictSharder: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]: + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: tensor_size = calculate_tensor_size(tensor) ret_block = None + ret_block_size = 0 if self.current_block_size + tensor_size > self.max_shard_size: ret_block = self.current_block + ret_block_size = self.current_block_size self.current_block = OrderedDict() self.current_block_size = 0 self.current_block[name] = tensor self.current_block_size += tensor_size - return ret_block + return ret_block, ret_block_size class GeminiDDP(ZeroDDP): diff --git a/pytest.ini b/pytest.ini index ac31ace4b..01e5cd217 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,4 +3,4 @@ markers = cpu: tests which can run on CPU gpu: tests which requires a single GPU dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features \ No newline at end of file + experiment: tests for experimental features diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index ca5ce1005..752ca706b 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,16 +1,21 @@ import tempfile import pytest import torch -import logging from torch.optim import Adam from torchvision.models import resnet18 -from pathlib import Path -import os -import subprocess from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO from colossalai.testing import clear_cache_before_run, parameterize +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs + # ======== # Note: # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now @@ -83,7 +88,6 @@ def test_sharded_checkpoint(use_safetensors: bool): suffix = ".bin" WEIGHTS_INDEX_NAME = "model.bin.index.json" - # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -104,6 +108,87 @@ def test_sharded_checkpoint(use_safetensors: bool): recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['bert']) +@parameterize('use_safetensors', [True, False]) +def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool): + from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification + + model_ckpt_dir = tempfile.TemporaryDirectory() + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, *_ = get_components_func() + + with ColoInitContext(device=get_current_device()): + bert_model = model_builder() + bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name) + config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + bert_model = ZeroDDP(bert_model, gemini_manager) + bert_model.train() + + ckpt_io = GeminiCheckpointIO() + if ckpt_io.coordinator.is_master(): + model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 + ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors) + new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name) + recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict()) + + model_ckpt_dir.cleanup() + + + +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['gpt2', 'bert']) +@parameterize('use_safetensors', [True, False]) +def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, *_ = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + new_model = model_builder() + + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager) + model.train() + + new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100) + new_chunk_manager = ChunkManager(new_config_dict) + new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) + new_model = ZeroDDP(new_model, new_gemini_manager) + + model_ckpt_dir = tempfile.TemporaryDirectory() + + ckpt_io = GeminiCheckpointIO() + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors) + + # load model + if ckpt_io.coordinator.is_master(): + ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True) + model_dict = model.state_dict(only_rank_0=True) + new_model_dict = new_model.state_dict(only_rank_0=True) + recursive_check(model_dict, new_model_dict) + + model_ckpt_dir.cleanup() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + hf_load_colossalai_checkpoint() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4, 4]) +@rerun_if_address_is_in_use() +def test_gemini_ckpIO(world_size): + spawn(run_dist, world_size) + # do recursive check for the optimizer state dict # if the value is a dict, compare its values @@ -117,10 +202,14 @@ def recursive_check(d1, d2): elif isinstance(v, list): for i in range(len(v)): if isinstance(v[i], torch.Tensor): + v[i] = v[i].to("cpu") + d2[k][i] = d2[k][i].to("cpu") assert torch.equal(v[i], d2[k][i]) else: assert v[i] == d2[k][i] elif isinstance(v, torch.Tensor): + v = v.to("cpu") + d2[k] = d2[k].to("cpu") assert torch.equal(v, d2[k]) else: assert v == d2[k] diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py index 96c26a1de..ad7d3a5a4 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -31,14 +31,13 @@ def exam_state_dict(placement_policy, model_name: str): zero_dict = model.state_dict(only_rank_0=False) accumulated_keys = set() # ensure number of shards > 1 - for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): for key, value in shard.items(): assert key not in accumulated_keys, f"key `{key}` is duplicated." accumulated_keys.add(key) assert key in zero_dict, f"{key} not in ZeRO dictionary." assert torch.equal(value, zero_dict[key]), f"{key} not equal." - def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')