# mirror of https://github.com/hpcaitech/ColossalAI
# 132 lines, 4.7 KiB, Python
import os
|
|
from abc import ABC, abstractmethod
|
|
from collections import Counter
|
|
from typing import Dict, Generator, List, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
|
|
from .meta import ParamDistMeta
|
|
from .utils import is_duplicated_list
|
|
|
|
|
|
class CheckpointReader(ABC):
    """Abstract interface for reading a checkpoint that may be split across ranks.

    Concrete readers decide how ``base_name`` is interpreted and how files are
    fetched. ``meta_list`` starts empty here; subclasses are expected to fill
    it with one meta dict per rank-level meta file.
    """

    def __init__(self, base_name: str) -> None:
        super().__init__()
        # Per-file meta dicts, populated by the concrete reader.
        self.meta_list = []
        # Root location of the checkpoint (meaning is defined by the subclass).
        self.base_name = base_name

    @abstractmethod
    def read(self, name: str) -> dict:
        """Load and return the object stored under ``name`` in this checkpoint."""

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Return aggregated checkpoint meta information.

        The annotation describes a 4-tuple: a list with one optional
        distribution-meta mapping per meta file, a mapping of name -> count,
        and two optional dicts. Their exact semantics are defined by the
        concrete reader.
        """

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model state shards belonging to ``rank``."""

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield dicts mapping rank -> model shard, covering all ranks."""

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer state shards belonging to ``rank``."""

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield dicts mapping rank -> optimizer shard, covering all ranks."""

    @abstractmethod
    def load_others(self) -> dict:
        """Return the auxiliary (non-model, non-optimizer) checkpoint data."""
|
|
|
|
|
|
class DiskCheckpointReader(CheckpointReader):
    """Checkpoint reader for a directory of files loaded with ``torch.load``.

    The directory must contain ``GLOBAL_META_FILE_NAME``, whose ``'meta'``
    entry lists the per-rank meta file names. Every per-rank meta dict is
    loaded eagerly in ``__init__`` and appended to ``self.meta_list``, so list
    index == rank. Model and optimizer shards are loaded lazily by the
    ``load_*`` generators.

    NOTE(review): ``torch.load`` unpickles arbitrary objects; only use this
    reader on trusted checkpoint directories.
    """

    def __init__(self, base_name: str) -> None:
        """Open checkpoint directory ``base_name`` and load all per-rank meta files.

        Raises:
            AssertionError: if ``base_name`` is not a directory, or if a meta
                file has no ``'dist_meta'`` while more than one meta file is
                listed in the global meta.
        """
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        """Load file ``name`` (relative to the checkpoint directory) with ``torch.load``."""
        # NOTE(review): newer torch releases default ``weights_only=True``, which
        # may reject custom objects stored in meta files - confirm torch version.
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Aggregate the per-rank meta dicts.

        Returns:
            A 4-tuple ``(dist_meta_list, param_count, param_to_os, paired_os)``:
            the per-rank ``'dist_meta'`` values (``None`` where absent), a
            ``Counter`` of how many meta files list each entry of their
            ``'params'``, and the ``'param_to_os'`` / ``'paired_os'`` values of
            rank 0 (``None`` where absent).

        Raises:
            AssertionError: if ``is_duplicated_list`` rejects the per-rank
                ``param_to_os`` or ``paired_os`` values (presumably meaning
                the ranks disagree - confirm against ``.utils``).
        """
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', None),
                       meta.get('paired_os', None)) for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count: number of meta files that list each parameter
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os / paired_os across ranks before returning rank 0's copy
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        """Lazily yield each file listed under ``shard_type`` in ``rank``'s meta (none if absent)."""
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def _iter_optimizer_shards(self, rank: int) -> Generator[dict, None, None]:
        """Yield ``rank``'s optimizer shards, copying the first shard's ``param_groups`` into later ones."""
        param_groups = None
        for idx, shard in enumerate(self._load_shard('optimizer', rank)):
            if idx == 0:
                param_groups = shard['param_groups']
            else:
                # later shards share the first shard's param_groups object
                shard['param_groups'] = param_groups
            yield shard

    def _round_robin(self, shard_iter_fn) -> Generator[Dict[int, dict], None, None]:
        """Step per-rank shard iterators in lockstep and yield ``{rank: shard}`` per step.

        ``shard_iter_fn(rank)`` must return an iterator of shards for that rank.
        Exhausted ranks are dropped; iteration stops once every rank is exhausted.
        """
        iterators = {rank: shard_iter_fn(rank) for rank in range(len(self.meta_list))}
        while iterators:
            shards = {}
            for rank in list(iterators):
                try:
                    shards[rank] = next(iterators[rank])
                except StopIteration:
                    del iterators[rank]
            if shards:
                yield shards

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield ``rank``'s model shards in the order listed in its meta."""
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield ``{rank: model_shard}`` dicts, taking the i-th shard of every rank at step i."""
        yield from self._round_robin(lambda rank: self._load_shard('model', rank))

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield ``rank``'s optimizer shards; every shard carries the first shard's ``param_groups``."""
        yield from self._iter_optimizer_shards(rank)

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield ``{rank: optimizer_shard}`` dicts, taking the i-th shard of every rank at step i.

        Each rank's ``param_groups`` are tracked per rank, so ranks with no
        optimizer shards do not shift other ranks' ``param_groups``.
        """
        yield from self._round_robin(self._iter_optimizer_shards)

    def load_others(self) -> dict:
        """Load the auxiliary checkpoint file ``OTHER_CKPT_FILE_NAME``."""
        return self.read(OTHER_CKPT_FILE_NAME)
|