ColossalAI/colossalai/utils/checkpoint_io/reader.py

132 lines
4.7 KiB
Python

import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple
import torch
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list
class CheckpointReader(ABC):
    """Abstract interface for reading a (possibly sharded) checkpoint.

    A reader is bound to a ``base_name`` and keeps the per-rank meta
    dictionaries it has discovered in ``meta_list``. Concrete subclasses
    decide how ``base_name`` is interpreted and how ``meta_list`` is filled.
    """

    def __init__(self, base_name: str) -> None:
        super().__init__()
        # Location identifier of the checkpoint; its meaning is defined by the subclass.
        self.base_name = base_name
        # One meta entry per rank; starts empty and is populated by subclasses.
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        """Return the dict stored under ``name`` inside this checkpoint."""

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Return the aggregated meta information of the checkpoint.

        Returns:
            A 4-tuple ``(dist_meta_list, param_count, param_to_os, paired_os)``:
            a list with one optional ``{name: ParamDistMeta}`` dict per rank,
            a ``{name: count}`` mapping, and two optional dicts.
        """

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model shards of a single ``rank``, one dict at a time."""

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield model shards of all ranks, as ``{rank: shard}`` dicts."""

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer shards of a single ``rank``, one dict at a time."""

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield optimizer shards of all ranks, as ``{rank: shard}`` dicts."""

    @abstractmethod
    def load_others(self) -> dict:
        """Return the dict holding everything that is neither model nor optimizer state."""
class DiskCheckpointReader(CheckpointReader):
    """Checkpoint reader that loads every file from a directory with ``torch.load``.

    ``base_name`` must be an existing directory. On construction the global
    meta file (``GLOBAL_META_FILE_NAME``) is read; its ``'meta'`` entry lists
    the per-rank meta files, which are loaded in order into ``meta_list``.
    The position of a meta dict in ``meta_list`` is used as its rank.
    """

    def __init__(self, base_name: str) -> None:
        """Load the global meta file and every per-rank meta file it lists.

        Args:
            base_name: Path of the checkpoint directory.

        Raises:
            AssertionError: If ``base_name`` is not a directory, or if a meta
                file without ``dist_meta`` appears alongside other meta files.
        """
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        meta_file_names = global_meta['meta']
        for meta_file_name in meta_file_names:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(meta_file_names) == 1, \
                    f'"{meta_file_name}" has no dist_meta but the checkpoint lists {len(meta_file_names)} meta files'
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        """Load the file ``name`` located inside the checkpoint directory.

        Args:
            name: File name relative to ``base_name``.

        Returns:
            Whatever object ``torch.load`` deserializes from that file.
        """
        # NOTE(review): torch.load is pickle-based and can execute arbitrary
        # code while unpickling -- only point this reader at trusted checkpoints.
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Aggregate the per-rank meta dicts into a single summary.

        Returns:
            A 4-tuple of:
            * the ``'dist_meta'`` entry of every rank (``None`` where absent),
            * a ``Counter`` of how often each parameter name occurs across the
              ``'params'`` entries of all ranks,
            * the ``'param_to_os'`` entry of the first rank (or ``None``),
            * the ``'paired_os'`` entry of the first rank (or ``None``).

        Raises:
            AssertionError: If ``is_duplicated_list`` rejects the collected
                ``param_to_os`` or ``paired_os`` entries.
        """
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', None),
                       meta.get('paired_os', None)) for meta in self.meta_list]
        # Transpose [(a0, b0, c0, d0), (a1, b1, c1, d1), ...] into four per-field tuples.
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os
        # presumably is_duplicated_list checks that all ranks agree -- only the
        # first entry is returned below; verify against .utils
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        """Lazily read every file listed under ``shard_type`` for ``rank``.

        Args:
            shard_type: Key in the rank's meta dict whose value is a list of
                file names (``'model'`` or ``'optimizer'`` in this class).
            rank: Index into ``meta_list``.

        Yields:
            The loaded content of each listed file, in listed order. Nothing is
            yielded if the key is missing.
        """
        meta = self.meta_list[rank]
        for name in meta.get(shard_type, []):
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model shards of ``rank`` one at a time."""
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield model shards of all ranks, one shard index per step.

        Yields:
            ``{rank: shard}`` holding the k-th model shard of every rank that
            still has one, for k = 0, 1, ... Ranks that run out of shards are
            simply absent from later dicts; iteration stops once no rank has a
            shard left.
        """
        # indices[i] is the position of the next unread shard of rank i.
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if not shards:
                break
            yield shards

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer shards of ``rank`` one at a time.

        The ``'param_groups'`` entry of the first shard is remembered and
        written into every following shard before it is yielded, so each
        yielded shard carries ``'param_groups'``.

        Raises:
            KeyError: If the first shard has no ``'param_groups'`` entry.
        """
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield optimizer shards of all ranks, one shard index per step.

        As in :meth:`load_optimizer`, the ``'param_groups'`` of each rank's
        first shard is copied into that rank's later shards.

        Yields:
            ``{rank: shard}`` holding the k-th optimizer shard of every rank
            that still has one, for k = 0, 1, ...; iteration stops once no rank
            has a shard left.

        Raises:
            KeyError: If the first shard of a rank has no ``'param_groups'``.
        """
        # indices[i] is the position of the next unread shard of rank i.
        indices = [0] * len(self.meta_list)
        # Keyed by rank rather than appended to a list: a rank without any
        # optimizer shard would otherwise shift the positions of all later
        # ranks, handing them the wrong param_groups or raising IndexError.
        param_groups: Dict[int, object] = {}
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shard = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups[i] = shard['param_groups']
                    else:
                        shard['param_groups'] = param_groups[i]
                    shards[i] = shard
                    indices[i] += 1
            if not shards:
                break
            yield shards

    def load_others(self) -> dict:
        """Return the content of the ``OTHER_CKPT_FILE_NAME`` file."""
        return self.read(OTHER_CKPT_FILE_NAME)