ColossalAI/colossalai/utils/checkpoint_io/writer.py


import os
from abc import ABC, abstractmethod
from typing import Optional

import torch

from .constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME,
                       OTHER_CKPT_FILE_NAME)

class CheckpointWriter(ABC):
    """Base class for checkpoint writers. Subclasses define how the model,
    optimizer, meta, and other states are persisted."""

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name = base_name
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        pass

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        pass

class DiskCheckpointWriter(CheckpointWriter):
    """Checkpoint writer that saves each state dict to a local directory via torch.save()."""

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        # Only the main process writes the global meta file, which lists the
        # per-rank meta checkpoint file names.
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)
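    # For a 2-process run, and assuming META_CKPT_FILE_NAME == 'meta.bin' (an
    # assumption; the real value lives in .constant), the global meta would be
    # {'meta': ['meta-rank0.bin', 'meta-rank1.bin']}.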

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name
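    # For illustration, assuming MODEL_CKPT_FILE_NAME == 'model.bin' (an
    # assumption; the real value lives in .constant), rank 1 of a multi-process
    # run saving its second shard would get:
    #   _get_checkpoint_name('model.bin', shard_idx=1) -> 'model-rank1-shard1.bin'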

    def save_model(self, model_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        # Record the shard file names saved so far, then persist the per-rank meta.
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)