mirror of https://github.com/hpcaitech/ColossalAI
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint --------- Co-authored-by: luchen <luchen@luchendeMBP.lan> Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>pull/3684/head^2
parent
0f785cb1f3
commit
307894f74d
|
@ -1,6 +1,9 @@
|
|||
import random
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from pathlib import Path
|
||||
import os
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -20,6 +23,13 @@ from colossalai.utils import get_current_device
|
|||
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
get_base_filenames,
|
||||
get_shard_filename
|
||||
)
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile
|
||||
|
||||
from .plugin_base import Plugin
|
||||
|
||||
__all__ = ['GeminiPlugin']
|
||||
|
@ -62,6 +72,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
if self.coordinator.is_master():
|
||||
super().save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
|
||||
"""
|
||||
Save sharded model
|
||||
"""
|
||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
|
||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
||||
total_size = 0
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
for idx, shard_pair in enumerate(state_dict_shard):
|
||||
if not self.coordinator.is_master():
|
||||
continue
|
||||
shard = shard_pair[0]
|
||||
shard_file = get_shard_filename(weights_name, idx)
|
||||
total_size = total_size + shard_pair[1]
|
||||
for key in shard.keys():
|
||||
index_file.append_weight_map(key, shard_file)
|
||||
|
||||
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
|
||||
save_state_dict(shard, checkpoint_file_path, use_safetensors)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(
|
||||
f"The model is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
|
||||
def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
|
||||
"""
|
||||
load shard model, load model from multiple files
|
||||
"""
|
||||
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
|
||||
|
||||
class GeminiModel(ModelWrapper):
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ class CheckpointIO(ABC):
|
|||
# the existence of index file means it is a sharded checkpoint
|
||||
ckpt_path = Path(checkpoint)
|
||||
index_file_exists, index_file_path = has_index_file(checkpoint)
|
||||
|
||||
|
||||
# return the origin model instead of the unwrapped model
|
||||
origin_model = model
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from pathlib import Path
|
||||
from functools import reduce
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import gc
|
||||
from typing import Optional
|
||||
from typing import Optional, Iterator, OrderedDict, Tuple
|
||||
|
||||
from .checkpoint_io_base import CheckpointIO
|
||||
from .index_file import CheckpointIndexFile
|
||||
|
@ -18,10 +18,9 @@ from .utils import (
|
|||
shard_checkpoint,
|
||||
load_shard_state_dict,
|
||||
load_state_dict_into_model,
|
||||
add_variant
|
||||
get_shard_filename,
|
||||
get_base_filenames
|
||||
)
|
||||
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
|
||||
|
||||
|
||||
__all__ = ['GeneralCheckpointIO']
|
||||
|
||||
|
@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
|
||||
# shard checkpoint
|
||||
state_dict = model.state_dict()
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
|
||||
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||
|
||||
# Save the model
|
||||
for shard_file, shard in shards.items():
|
||||
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
|
||||
total_size = 0
|
||||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
for idx, shard_pair in enumerate(state_dict_shard):
|
||||
shard = shard_pair[0]
|
||||
shard_file = get_shard_filename(weights_name, idx)
|
||||
total_size = total_size + shard_pair[1]
|
||||
for key in shard.keys():
|
||||
index_file.append_weight_map(key, shard_file)
|
||||
|
||||
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
|
||||
save_state_dict(shard, checkpoint_file_path, use_safetensors)
|
||||
|
||||
# save index file
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
|
||||
save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
logging.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"The model is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
|
||||
use_safetensors: bool = False, load_sub_module: bool = True):
|
||||
"""
|
||||
load shard model, load model from multiple files
|
||||
"""
|
||||
|
@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
# read checkpoint index file
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
|
||||
missing_keys = ckpt_index_file.get_all_param_names()
|
||||
missing_keys = []
|
||||
|
||||
for shard_file in checkpoint_files:
|
||||
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
|
||||
load_state_dict_into_model(model, state_dict, missing_keys, strict)
|
||||
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
if strict and len(missing_keys) > 0:
|
||||
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in missing_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
if strict:
|
||||
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
|
||||
if len(remain_keys) > 0:
|
||||
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
|
||||
', '.join('"{}"'.format(k) for k in missing_keys))
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
import os
|
||||
import json
|
||||
|
||||
from .utils import is_dtensor_checkpoint
|
||||
|
||||
|
@ -18,8 +20,8 @@ class CheckpointIndexFile:
|
|||
>>> index.export('new_index.json')
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.root_path = None
|
||||
def __init__(self, root_path=None) -> None:
|
||||
self.root_path = root_path
|
||||
self.metadata: dict = dict()
|
||||
self.weight_map: dict = dict()
|
||||
|
||||
|
@ -154,3 +156,13 @@ class CheckpointIndexFile:
|
|||
Get all the weight keys.
|
||||
"""
|
||||
return list(self.weight_map.keys())
|
||||
|
||||
def write_index_file(self, save_index_file):
|
||||
"""
|
||||
Wriete index file.
|
||||
"""
|
||||
save_index_file = os.path.join(self.root_path, save_index_file)
|
||||
index = {"metadata": self.metadata, "weight_map": self.weight_map}
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
|
||||
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
import re
|
||||
|
||||
|
@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
|
|||
# ======================================
|
||||
# Helper functions for saving shard file
|
||||
# ======================================
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
|
||||
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
"""
|
||||
sharded_state_dicts = []
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
total_size = 0
|
||||
|
||||
for key, weight in state_dict.items():
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
if type(weight) != DTensor:
|
||||
weight_size = calculate_tensor_size(weight)
|
||||
|
||||
# If this weight is going to tip up over the maximal size, we split.
|
||||
if current_block_size + weight_size > max_shard_size:
|
||||
sharded_state_dicts.append(current_block)
|
||||
ret_block = current_block
|
||||
ret_block_size = current_block_size
|
||||
current_block = {}
|
||||
current_block_size = 0
|
||||
|
||||
current_block[key] = weight
|
||||
current_block_size += weight_size
|
||||
total_size += weight_size
|
||||
|
||||
if ret_block != None:
|
||||
yield ret_block, ret_block_size
|
||||
|
||||
# Add the last block
|
||||
sharded_state_dicts.append(current_block)
|
||||
yield current_block, current_block_size
|
||||
|
||||
# If we only have one shard, we return it
|
||||
if len(sharded_state_dicts) == 1:
|
||||
return {weights_name: sharded_state_dicts[0]}, None
|
||||
|
||||
# Otherwise, let's build the index
|
||||
weight_map = {}
|
||||
shards = {}
|
||||
|
||||
for idx, shard in enumerate(sharded_state_dicts):
|
||||
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
|
||||
shard_file = shard_file.replace(
|
||||
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
|
||||
)
|
||||
shards[shard_file] = shard
|
||||
for key in shard.keys():
|
||||
weight_map[key] = shard_file
|
||||
|
||||
# Add the metadata
|
||||
metadata = {"total_size": total_size}
|
||||
index = {"metadata": metadata, "weight_map": weight_map}
|
||||
return shards, index
|
||||
|
||||
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
|
||||
"""
|
||||
|
@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
|
|||
else:
|
||||
return torch.load(checkpoint_file)
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
|
||||
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants.
|
||||
|
||||
|
@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
|
|||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module: nn.Module, state_dict, prefix=""):
|
||||
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
||||
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
|
||||
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
||||
# state_dict
|
||||
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||
module._load_from_state_dict(*args)
|
||||
if load_sub_module:
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".")
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, state_dict, prefix + name + ".")
|
||||
|
||||
load(model, state_dict, "")
|
||||
load(model, state_dict, "", load_sub_module)
|
||||
del load
|
||||
|
||||
# deal with missing key
|
||||
if len(missing_keys) > 0:
|
||||
deleted_keys = []
|
||||
for key in missing_keys:
|
||||
if key not in sub_missing_keys:
|
||||
deleted_keys.append(key)
|
||||
for key in deleted_keys:
|
||||
missing_keys.remove(key)
|
||||
missing_keys = missing_keys.append(sub_missing_keys)
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
|
@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
|||
weights_name = ".".join(splits)
|
||||
|
||||
return weights_name
|
||||
|
||||
|
||||
def get_base_filenames(variant: str=None, use_safetensors: bool=False):
|
||||
"""
|
||||
generate base weight filenames
|
||||
"""
|
||||
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
|
||||
weights_name = add_variant(weights_name, variant)
|
||||
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
|
||||
save_index_file = add_variant(save_index_file, variant)
|
||||
|
||||
return weights_name, save_index_file
|
||||
|
||||
def get_shard_filename(weights_name: str, idx: int):
|
||||
"""
|
||||
get shard file name
|
||||
"""
|
||||
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
|
||||
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
|
||||
return shard_file
|
|
@ -2,7 +2,7 @@ import itertools
|
|||
from collections import OrderedDict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, Iterator, List, Optional, Union
|
||||
from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -96,8 +96,35 @@ class ZeroDDP(ColoDDP):
|
|||
param_name = m_name + '.' + p_name if m_name else p_name
|
||||
self.name2param[param_name] = p_var
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
|
||||
self._cast_buffers()
|
||||
|
||||
def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
|
||||
|
||||
r"""
|
||||
Args:
|
||||
memo: a memo to store the set of modules already added to the result
|
||||
prefix: a prefix that will be added to the name of the module
|
||||
remove_duplicate: whether to remove the duplicated module instances in the result
|
||||
or not
|
||||
"""
|
||||
|
||||
if memo is None:
|
||||
memo = set()
|
||||
self_non_persistent_set = set()
|
||||
if module not in memo:
|
||||
if remove_duplicate:
|
||||
memo.add(module)
|
||||
self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
|
||||
for name, sub_module in module._modules.items():
|
||||
if sub_module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
|
||||
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
|
||||
return self_non_persistent_set
|
||||
|
||||
|
||||
def _post_forward(self):
|
||||
"""This function is only triggered for inference.
|
||||
"""
|
||||
|
@ -604,7 +631,7 @@ class ZeroDDP(ColoDDP):
|
|||
keep_vars: bool = False,
|
||||
max_shard_size: int = 1024,
|
||||
only_rank_0: bool = True,
|
||||
dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
|
||||
dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are included.
|
||||
|
@ -644,9 +671,9 @@ class ZeroDDP(ColoDDP):
|
|||
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
|
||||
gathered_param = gathered_param_buffer.pop(fp32_param)
|
||||
|
||||
block = sharder.append(prefix + name, gathered_param)
|
||||
block, block_size = sharder.append(prefix + name, gathered_param)
|
||||
if block is not None:
|
||||
yield block
|
||||
yield block, block_size
|
||||
|
||||
del fp16_to_fp32
|
||||
del gathered_param_buffer
|
||||
|
@ -655,19 +682,19 @@ class ZeroDDP(ColoDDP):
|
|||
for name, buf in self.named_buffers():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
block = sharder.append(prefix + name, buffer)
|
||||
block, block_size = sharder.append(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block
|
||||
yield block, block_size
|
||||
# save extra states
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state",
|
||||
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
|
||||
extra_state = self.get_extra_state()
|
||||
block = sharder.append(extra_state_key, extra_state)
|
||||
block, block_size = sharder.append(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block
|
||||
yield block, block_size
|
||||
|
||||
yield sharder.current_block
|
||||
yield sharder.current_block, sharder.current_block_size
|
||||
|
||||
|
||||
class _StateDictSharder:
|
||||
|
@ -677,16 +704,18 @@ class _StateDictSharder:
|
|||
self.current_block = OrderedDict()
|
||||
self.current_block_size = 0
|
||||
|
||||
def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
|
||||
def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
|
||||
tensor_size = calculate_tensor_size(tensor)
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
if self.current_block_size + tensor_size > self.max_shard_size:
|
||||
ret_block = self.current_block
|
||||
ret_block_size = self.current_block_size
|
||||
self.current_block = OrderedDict()
|
||||
self.current_block_size = 0
|
||||
self.current_block[name] = tensor
|
||||
self.current_block_size += tensor_size
|
||||
return ret_block
|
||||
return ret_block, ret_block_size
|
||||
|
||||
|
||||
class GeminiDDP(ZeroDDP):
|
||||
|
|
|
@ -3,4 +3,4 @@ markers =
|
|||
cpu: tests which can run on CPU
|
||||
gpu: tests which requires a single GPU
|
||||
dist: tests which are run in a multi-GPU or multi-machine environment
|
||||
experiment: tests for experimental features
|
||||
experiment: tests for experimental features
|
||||
|
|
|
@ -1,16 +1,21 @@
|
|||
import tempfile
|
||||
import pytest
|
||||
import torch
|
||||
import logging
|
||||
from torch.optim import Adam
|
||||
from torchvision.models import resnet18
|
||||
from pathlib import Path
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from colossalai.checkpoint_io import GeneralCheckpointIO
|
||||
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
# ========
|
||||
# Note:
|
||||
# 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
|
||||
|
@ -83,7 +88,6 @@ def test_sharded_checkpoint(use_safetensors: bool):
|
|||
suffix = ".bin"
|
||||
WEIGHTS_INDEX_NAME = "model.bin.index.json"
|
||||
|
||||
# model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
||||
|
||||
|
@ -104,6 +108,87 @@ def test_sharded_checkpoint(use_safetensors: bool):
|
|||
recursive_check(model.state_dict(), new_model.state_dict())
|
||||
recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('model_name', ['bert'])
|
||||
@parameterize('use_safetensors', [True, False])
|
||||
def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool):
|
||||
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification
|
||||
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, *_ = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
bert_model = model_builder()
|
||||
bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name)
|
||||
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
bert_model = ZeroDDP(bert_model, gemini_manager)
|
||||
bert_model.train()
|
||||
|
||||
ckpt_io = GeminiCheckpointIO()
|
||||
if ckpt_io.coordinator.is_master():
|
||||
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
|
||||
ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors)
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
|
||||
recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict())
|
||||
|
||||
model_ckpt_dir.cleanup()
|
||||
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
@parameterize('use_safetensors', [True, False])
|
||||
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, *_ = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
new_model = model_builder()
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
model.train()
|
||||
|
||||
new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
|
||||
new_chunk_manager = ChunkManager(new_config_dict)
|
||||
new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
|
||||
new_model = ZeroDDP(new_model, new_gemini_manager)
|
||||
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
ckpt_io = GeminiCheckpointIO()
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors)
|
||||
|
||||
# load model
|
||||
if ckpt_io.coordinator.is_master():
|
||||
ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True)
|
||||
model_dict = model.state_dict(only_rank_0=True)
|
||||
new_model_dict = new_model.state_dict(only_rank_0=True)
|
||||
recursive_check(model_dict, new_model_dict)
|
||||
|
||||
model_ckpt_dir.cleanup()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_state_dict()
|
||||
hf_load_colossalai_checkpoint()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [4, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gemini_ckpIO(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
# do recursive check for the optimizer state dict
|
||||
# if the value is a dict, compare its values
|
||||
|
@ -117,10 +202,14 @@ def recursive_check(d1, d2):
|
|||
elif isinstance(v, list):
|
||||
for i in range(len(v)):
|
||||
if isinstance(v[i], torch.Tensor):
|
||||
v[i] = v[i].to("cpu")
|
||||
d2[k][i] = d2[k][i].to("cpu")
|
||||
assert torch.equal(v[i], d2[k][i])
|
||||
else:
|
||||
assert v[i] == d2[k][i]
|
||||
elif isinstance(v, torch.Tensor):
|
||||
v = v.to("cpu")
|
||||
d2[k] = d2[k].to("cpu")
|
||||
assert torch.equal(v, d2[k])
|
||||
else:
|
||||
assert v == d2[k]
|
||||
|
|
|
@ -31,14 +31,13 @@ def exam_state_dict(placement_policy, model_name: str):
|
|||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
accumulated_keys = set()
|
||||
# ensure number of shards > 1
|
||||
for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
|
||||
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
|
||||
for key, value in shard.items():
|
||||
assert key not in accumulated_keys, f"key `{key}` is duplicated."
|
||||
accumulated_keys.add(key)
|
||||
assert key in zero_dict, f"{key} not in ZeRO dictionary."
|
||||
assert torch.equal(value, zero_dict[key]), f"{key} not equal."
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
|
Loading…
Reference in New Issue