import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type

from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter

# Registry mapping a backend name to the backend class registered under it.
_backends: Dict[str, Type['CheckpointIOBackend']] = {}


def register(name: str):
    """Return a class decorator that registers a backend class under ``name``.

    The name is checked for uniqueness when ``register(name)`` itself is
    called, i.e. before the decorator is applied to the class.

    Args:
        name: Key under which the decorated class is stored in the registry.

    Returns:
        A decorator that stores the class in the registry and returns the
        class unchanged.

    Raises:
        AssertionError: If ``name`` is already registered.
    """
    assert name not in _backends, f'"{name}" is registered'

    def wrapper(cls):
        _backends[name] = cls
        return cls

    return wrapper


def get_backend(name: str) -> 'CheckpointIOBackend':
    """Create a new instance of the backend registered under ``name``.

    Args:
        name: Name previously passed to :func:`register`.

    Returns:
        A freshly constructed backend instance (the class is called with no
        arguments).

    Raises:
        AssertionError: If no backend is registered under ``name``.
    """
    assert name in _backends, f'Unsupported backend "{name}"'
    return _backends[name]()


class CheckpointIOBackend(ABC):
    """Abstract interface for checkpoint I/O backends.

    Subclasses provide reader/writer factories and manage temporary
    locations; paths of temporaries are tracked in ``self.temps``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Temporary locations handed out by get_temp() and not yet cleaned.
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a checkpoint writer for ``base_name``."""
        pass

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a checkpoint reader for ``base_name``."""
        pass

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Create a temporary location related to ``base_name`` and return it."""
        pass

    @abstractmethod
    def clean_temp(self) -> None:
        """Remove the temporary locations created by :meth:`get_temp`."""
        pass


@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Backend registered as ``'disk'``, built on the disk reader/writer."""

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a ``DiskCheckpointWriter`` built from the given arguments."""
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``DiskCheckpointReader`` for ``base_name``."""
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        """Create a temporary directory inside ``base_name`` and track it.

        Args:
            base_name: Existing directory in which the temporary directory
                is created (passed as ``dir`` to ``tempfile.mkdtemp``).

        Returns:
            The path of the newly created temporary directory.
        """
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        """Delete every tracked temporary directory and stop tracking it.

        Directories are removed in creation order. Each path is dropped from
        ``self.temps`` only after its removal succeeds, so calling this
        method again (or after further ``get_temp`` calls) never retries an
        already-deleted directory. If a removal raises, the error propagates
        and that path and all later ones remain tracked.
        """
        while self.temps:
            shutil.rmtree(self.temps[0])
            # Forget the path only once the directory is actually gone.
            del self.temps[0]