import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple

import torch

from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list


class CheckpointReader(ABC):
    """Abstract interface for reading (possibly sharded) checkpoints."""

    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        pass

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        pass

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_others(self) -> dict:
        pass


class DiskCheckpointReader(CheckpointReader):
    """Reads checkpoint shards and metadata stored as files under ``base_name``."""

    def __init__(self, base_name: str) -> None:
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', None),
                       meta.get('paired_os', None)) for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        # yield {rank: shard} dicts, one shard per rank per iteration,
        # until every rank's model checkpoint files are exhausted
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        # param_groups are stored only in the first shard; propagate them to the rest
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        param_groups = []
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        # remember each rank's param_groups from its first shard
                        param_groups.append(shards[i]['param_groups'])
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_others(self) -> dict:
        return self.read(OTHER_CKPT_FILE_NAME)
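
# Example usage (a minimal sketch, kept as comments; "./ckpt_dir" is a
# hypothetical checkpoint directory that must already contain the global meta
# file and the per-rank meta/shard files produced by the matching writer):
#
#     reader = DiskCheckpointReader('./ckpt_dir')
#     dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
#     for shards in reader.load_models():
#         # `shards` maps rank -> one model state-dict shard per iteration;
#         # merging or redistributing these shards is up to the caller
#         ...
#     others = reader.load_others()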