# ColossalAI/colossalai/utils/checkpoint_io/reader.py

import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple
import torch
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list
class CheckpointReader(ABC):
    """Abstract interface for reading a (possibly sharded) checkpoint.

    Concrete subclasses bind this interface to a storage backend (e.g. a
    local directory) and populate ``self.meta_list`` with one metadata dict
    per saved rank.
    """

    def __init__(self, base_name: str) -> None:
        super().__init__()
        # Backend-specific checkpoint location/identifier.
        self.base_name = base_name
        # One metadata dict per rank; filled in by subclasses.
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        """Load and return a single named checkpoint artifact."""
        ...

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Return (dist_meta per rank, param occurrence counts, param_to_os, paired_os)."""
        ...

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model state-dict shards saved by one rank."""
        ...

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield dicts mapping rank -> model shard, one shard index at a time."""
        ...

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer state-dict shards saved by one rank."""
        ...

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield dicts mapping rank -> optimizer shard, one shard index at a time."""
        ...

    @abstractmethod
    def load_others(self) -> dict:
        """Return miscellaneous (non-model, non-optimizer) checkpoint data."""
        ...
class DiskCheckpointReader(CheckpointReader):
    """Read a checkpoint laid out on local disk.

    The directory must contain a global metadata file (``GLOBAL_META_FILE_NAME``)
    listing per-rank metadata files; each per-rank metadata dict may carry
    ``dist_meta``, ``params``, ``param_to_os``, ``paired_os`` and lists of
    shard file names under the ``'model'`` and ``'optimizer'`` keys.
    """

    def __init__(self, base_name: str) -> None:
        """Open the checkpoint directory and load all per-rank metadata.

        Args:
            base_name: Path to the checkpoint directory.
        """
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        """Load one artifact (a torch-serialized file) from the checkpoint dir."""
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Aggregate metadata across ranks.

        Returns:
            Tuple of (per-rank dist_meta list, Counter of how many ranks hold
            each parameter, param_to_os mapping, paired_os mapping). The last
            two must be identical across ranks and are validated as such.
        """
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os',
                                                                             None), meta.get('paired_os', None))
                      for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count: number of ranks that own each parameter
        param_count = Counter(p for params in params_list for p in params)
        # validate that param_to_os / paired_os agree across all ranks
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        """Yield each shard file of ``shard_type`` ('model'/'optimizer') for ``rank``."""
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model shards saved by one rank, in order."""
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: shard} dicts, advancing one shard index per iteration.

        Ranks with fewer shards simply drop out of later yields; iteration
        stops once every rank's shards are exhausted.
        """
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield optimizer shards for one rank, replicating ``param_groups``.

        Only the first shard stores ``param_groups``; it is copied into every
        subsequent shard so each yielded dict is self-contained.
        """
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: optimizer shard} dicts, one shard index per iteration.

        As in :meth:`load_optimizer`, each rank's ``param_groups`` comes from
        its first shard and is propagated to its later shards.
        """
        indices = [0] * len(self.meta_list)
        # Keyed by rank: a plain list indexed by ``i`` breaks (IndexError /
        # wrong pairing) when some earlier rank has no optimizer shards.
        param_groups: Dict[int, list] = {}
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups[i] = shards[i]['param_groups']
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_others(self) -> dict:
        """Return the miscellaneous checkpoint data saved alongside model/optimizer."""
        return self.read(OTHER_CKPT_FILE_NAME)