[booster] gemini plugin support shard checkpoint (#3610)

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin add shard checkpoint save/load

* gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

* [API Refactoring]gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
pull/3684/head^2
jiangmingyan 2023-05-05 14:37:21 +08:00 committed by GitHub
parent 0f785cb1f3
commit 307894f74d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 268 additions and 96 deletions

View File

@ -1,6 +1,9 @@
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union
from pathlib import Path
import os
import logging
import numpy as np
import torch
@ -20,6 +23,13 @@ from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.memory_tracer import MemStats
from colossalai.checkpoint_io.utils import (
get_base_filenames,
get_shard_filename
)
from colossalai.checkpoint_io import CheckpointIndexFile
from .plugin_base import Plugin
__all__ = ['GeminiPlugin']
@ -62,6 +72,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
if not self.coordinator.is_master():
continue
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
"""
load shard model, load model from multiple files
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
class GeminiModel(ModelWrapper):

View File

@ -86,7 +86,7 @@ class CheckpointIO(ABC):
# the existence of index file means it is a sharded checkpoint
ckpt_path = Path(checkpoint)
index_file_exists, index_file_path = has_index_file(checkpoint)
# return the origin model instead of the unwrapped model
origin_model = model

View File

@ -1,12 +1,12 @@
from pathlib import Path
from functools import reduce
import torch.nn as nn
from torch.optim import Optimizer
import logging
import os
import json
import gc
from typing import Optional
from typing import Optional, Iterator, OrderedDict, Tuple
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
@ -18,10 +18,9 @@ from .utils import (
shard_checkpoint,
load_shard_state_dict,
load_state_dict_into_model,
add_variant
get_shard_filename,
get_base_filenames
)
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
__all__ = ['GeneralCheckpointIO']
@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO):
# shard checkpoint
state_dict = model.state_dict()
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
# Save the model
for shard_file, shard in shards.items():
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
# save index file
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
use_safetensors: bool = False, load_sub_module: bool = True):
"""
load shard model, load model from multiple files
"""
@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO):
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
missing_keys = ckpt_index_file.get_all_param_names()
missing_keys = []
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
load_state_dict_into_model(model, state_dict, missing_keys, strict)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
del state_dict
gc.collect()
if strict and len(missing_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))

View File

@ -1,6 +1,8 @@
import json
from pathlib import Path
from typing import Any, List, Union
import os
import json
from .utils import is_dtensor_checkpoint
@ -18,8 +20,8 @@ class CheckpointIndexFile:
>>> index.export('new_index.json')
"""
def __init__(self) -> None:
self.root_path = None
def __init__(self, root_path=None) -> None:
self.root_path = root_path
self.metadata: dict = dict()
self.weight_map: dict = dict()
@ -154,3 +156,13 @@ class CheckpointIndexFile:
Get all the weight keys.
"""
return list(self.weight_map.keys())
def write_index_file(self, save_index_file):
"""
Wriete index file.
"""
save_index_file = os.path.join(self.root_path, save_index_file)
index = {"metadata": self.metadata, "weight_map": self.weight_map}
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

View File

@ -2,7 +2,7 @@
from pathlib import Path
import torch
import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
from colossalai.tensor.d_tensor.d_tensor import DTensor
import re
@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0
for key, weight in state_dict.items():
ret_block = None
ret_block_size = 0
if type(weight) != DTensor:
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
total_size += weight_size
if ret_block != None:
yield ret_block, ret_block_size
# Add the last block
sharded_state_dicts.append(current_block)
yield current_block, current_block_size
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file
# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
"""
@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
else:
return torch.load(checkpoint_file)
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.
@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
if metadata is not None:
state_dict._metadata = metadata
def load(module: nn.Module, state_dict, prefix=""):
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
if load_sub_module:
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(model, state_dict, "")
load(model, state_dict, "", load_sub_module)
del load
# deal with missing key
if len(missing_keys) > 0:
deleted_keys = []
for key in missing_keys:
if key not in sub_missing_keys:
deleted_keys.append(key)
for key in deleted_keys:
missing_keys.remove(key)
missing_keys = missing_keys.append(sub_missing_keys)
if strict:
if len(unexpected_keys) > 0:
@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
weights_name = ".".join(splits)
return weights_name
def get_base_filenames(variant: str=None, use_safetensors: bool=False):
"""
generate base weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)
return weights_name, save_index_file
def get_shard_filename(weights_name: str, idx: int):
"""
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
return shard_file

View File

@ -2,7 +2,7 @@ import itertools
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Union
from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
import torch
import torch.distributed as dist
@ -96,8 +96,35 @@ class ZeroDDP(ColoDDP):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
self._cast_buffers()
def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
r"""
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
"""
if memo is None:
memo = set()
self_non_persistent_set = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
def _post_forward(self):
"""This function is only triggered for inference.
"""
@ -604,7 +631,7 @@ class ZeroDDP(ColoDDP):
keep_vars: bool = False,
max_shard_size: int = 1024,
only_rank_0: bool = True,
dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Both parameters and persistent buffers (e.g. running averages) are included.
@ -644,9 +671,9 @@ class ZeroDDP(ColoDDP):
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
block = sharder.append(prefix + name, gathered_param)
block, block_size = sharder.append(prefix + name, gathered_param)
if block is not None:
yield block
yield block, block_size
del fp16_to_fp32
del gathered_param_buffer
@ -655,19 +682,19 @@ class ZeroDDP(ColoDDP):
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block = sharder.append(prefix + name, buffer)
block, block_size = sharder.append(prefix + name, buffer)
if block is not None:
yield block
yield block, block_size
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = self.get_extra_state()
block = sharder.append(extra_state_key, extra_state)
block, block_size = sharder.append(extra_state_key, extra_state)
if block is not None:
yield block
yield block, block_size
yield sharder.current_block
yield sharder.current_block, sharder.current_block_size
class _StateDictSharder:
@ -677,16 +704,18 @@ class _StateDictSharder:
self.current_block = OrderedDict()
self.current_block_size = 0
def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
if self.current_block_size + tensor_size > self.max_shard_size:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block
return ret_block, ret_block_size
class GeminiDDP(ZeroDDP):

View File

@ -3,4 +3,4 @@ markers =
cpu: tests which can run on CPU
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
experiment: tests for experimental features

View File

@ -1,16 +1,21 @@
import tempfile
import pytest
import torch
import logging
from torch.optim import Adam
from torchvision.models import resnet18
from pathlib import Path
import os
import subprocess
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.testing import clear_cache_before_run, parameterize
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs
# ========
# Note:
# 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
@ -83,7 +88,6 @@ def test_sharded_checkpoint(use_safetensors: bool):
suffix = ".bin"
WEIGHTS_INDEX_NAME = "model.bin.index.json"
# model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
@ -104,6 +108,87 @@ def test_sharded_checkpoint(use_safetensors: bool):
recursive_check(model.state_dict(), new_model.state_dict())
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['bert'])
@parameterize('use_safetensors', [True, False])
def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool):
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification
model_ckpt_dir = tempfile.TemporaryDirectory()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()
with ColoInitContext(device=get_current_device()):
bert_model = model_builder()
bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name)
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
bert_model = ZeroDDP(bert_model, gemini_manager)
bert_model.train()
ckpt_io = GeminiCheckpointIO()
if ckpt_io.coordinator.is_master():
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors)
new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict())
model_ckpt_dir.cleanup()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
@parameterize('use_safetensors', [True, False])
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, *_ = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
new_model = model_builder()
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
model.train()
new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
new_chunk_manager = ChunkManager(new_config_dict)
new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
new_model = ZeroDDP(new_model, new_gemini_manager)
model_ckpt_dir = tempfile.TemporaryDirectory()
ckpt_io = GeminiCheckpointIO()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors)
# load model
if ckpt_io.coordinator.is_master():
ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True)
model_dict = model.state_dict(only_rank_0=True)
new_model_dict = new_model.state_dict(only_rank_0=True)
recursive_check(model_dict, new_model_dict)
model_ckpt_dir.cleanup()
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
hf_load_colossalai_checkpoint()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4, 4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
# do recursive check for the optimizer state dict
# if the value is a dict, compare its values
@ -117,10 +202,14 @@ def recursive_check(d1, d2):
elif isinstance(v, list):
for i in range(len(v)):
if isinstance(v[i], torch.Tensor):
v[i] = v[i].to("cpu")
d2[k][i] = d2[k][i].to("cpu")
assert torch.equal(v[i], d2[k][i])
else:
assert v[i] == d2[k][i]
elif isinstance(v, torch.Tensor):
v = v.to("cpu")
d2[k] = d2[k].to("cpu")
assert torch.equal(v, d2[k])
else:
assert v == d2[k]

View File

@ -31,14 +31,13 @@ def exam_state_dict(placement_policy, model_name: str):
zero_dict = model.state_dict(only_rank_0=False)
accumulated_keys = set()
# ensure number of shards > 1
for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
for key, value in shard.items():
assert key not in accumulated_keys, f"key `{key}` is duplicated."
accumulated_keys.add(key)
assert key in zero_dict, f"{key} not in ZeRO dictionary."
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')