# mirror of https://github.com/hpcaitech/ColossalAI
import os
from abc import ABC, abstractmethod
from typing import Optional

import torch

from .constant import (
    GLOBAL_META_FILE_NAME,
    META_CKPT_FILE_NAME,
    MODEL_CKPT_FILE_NAME,
    OPTIM_CKPT_FILE_NAME,
    OTHER_CKPT_FILE_NAME,
)


class CheckpointWriter(ABC):
    """Base class for checkpoint writers.

    Args:
        base_name (str): Base name (e.g. a directory) under which checkpoint files are written.
        overwrite (bool): Whether an existing checkpoint file may be overwritten. Defaults to False.
        rank (int): Rank of the current process. Defaults to 0.
        world_size (int): Total number of processes writing checkpoints. Defaults to 1.
    """

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name = base_name
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        pass

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        pass


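# Call-order contract (inferred from the assertions in DiskCheckpointWriter below,
# not stated explicitly in the original source): save_model()/save_optimizer() may
# be called one or more times, once per shard; save_meta() is then called exactly
# once; save_others() may be called at any point and writes only on the main process.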
class DiskCheckpointWriter(CheckpointWriter):
    """Checkpoint writer that saves checkpoints to disk via ``torch.save``.

    ``base_name`` is interpreted as a directory and is created if it does not exist.
    """

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        # Only the main process writes the global meta file, which lists the
        # per-rank meta file names so that all shards can be located on load.
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        # Derive a rank- and shard-qualified file name, e.g. 'model.bin' ->
        # 'model-rank1-shard0.bin' for rank 1, shard 0 in a distributed run.
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name

    def save_model(self, model_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        # Attach the accumulated shard file names before writing, so the meta
        # checkpoint records which files make up the model/optimizer state.
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        # Miscellaneous state (e.g. epoch, lr scheduler state) is written once,
        # by the main process only.
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)
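

# Minimal usage sketch (not part of the original file). It assumes the names in
# .constant are '.bin' file names (e.g. MODEL_CKPT_FILE_NAME == 'model.bin'),
# which the '.bin' substitutions in _get_checkpoint_name rely on, and runs as a
# single process. Run as a module (python -m <package>.<this_module>) so the
# relative import of .constant resolves.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    writer = DiskCheckpointWriter('./checkpoint', overwrite=True)
    writer.save_model(model.state_dict())          # e.g. ./checkpoint/model-shard0.bin
    writer.save_optimizer(optimizer.state_dict())  # e.g. ./checkpoint/optimizer-shard0.bin
    writer.save_meta({'epoch': 0})                 # records the shard file lists under 'model'/'optimizer'
    writer.save_others({'step': 100})              # main-process-only extras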