mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
1.9 KiB
75 lines
1.9 KiB
2 years ago
|
import shutil
|
||
|
import tempfile
|
||
|
from abc import ABC, abstractmethod
|
||
|
from typing import Dict, List, Type
|
||
|
|
||
|
from .reader import CheckpointReader, DiskCheckpointReader
|
||
|
from .writer import CheckpointWriter, DiskCheckpointWriter
|
||
|
|
||
|
# Registry of backend classes keyed by backend name.
# Written by the `register` decorator and read by `get_backend`.
_backends: Dict[str, Type['CheckpointIOBackend']] = {}
|
||
|
|
||
|
|
||
|
def register(name: str):
    """Create a class decorator that files a backend class under ``name``.

    The name is checked against the registry when this function is called
    (i.e. when the decorator is built), not when the class is decorated.

    Args:
        name: Key under which the decorated class is stored in ``_backends``.

    Returns:
        A decorator that stores the class in ``_backends`` and returns the
        class unchanged.

    Raises:
        AssertionError: If ``name`` is already present in ``_backends``.
    """
    assert name not in _backends, f'"{name}" is registered'

    def _bind(backend_cls):
        # Record the class, then hand it back untouched so the decorated
        # name still refers to the original class object.
        _backends[name] = backend_cls
        return backend_cls

    return _bind
|
||
|
|
||
|
|
||
|
def get_backend(name: str) -> 'CheckpointIOBackend':
    """Instantiate the backend class registered under ``name``.

    Args:
        name: Key previously passed to ``register``.

    Returns:
        A new instance of the registered class, constructed with no arguments.

    Raises:
        AssertionError: If no class is registered under ``name``.
    """
    assert name in _backends, f'Unsupported backend "{name}"'
    backend_cls = _backends[name]
    return backend_cls()
|
||
|
|
||
|
|
||
|
class CheckpointIOBackend(ABC):
    """Abstract interface every checkpoint I/O backend must implement.

    Subclasses provide a writer and a reader for a checkpoint location,
    hand out temporary locations, and clean those locations up again.
    """

    def __init__(self) -> None:
        super().__init__()
        # Names of temporary locations; starts empty. Concrete backends are
        # expected to record what ``get_temp`` hands out here so that
        # ``clean_temp`` can find it later.
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a ``CheckpointWriter`` for ``base_name``.

        Args:
            base_name: Location of the checkpoint to write.
            overwrite: Overwrite flag for the writer (defaults to ``False``).
            rank: Presumably the rank of the calling process -- confirm
                against the ``CheckpointWriter`` implementation.
            world_size: Presumably the total number of processes -- confirm
                against the ``CheckpointWriter`` implementation.
        """

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``CheckpointReader`` for the checkpoint at ``base_name``."""

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Return the name of a temporary location associated with ``base_name``."""

    @abstractmethod
    def clean_temp(self) -> None:
        """Release the temporary locations this backend has handed out."""
|
||
|
|
||
|
|
||
|
@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Checkpoint backend registered under the name ``'disk'``.

    Writers and readers are the disk implementations imported by this
    module; temporary locations are directories created with
    ``tempfile.mkdtemp``.
    """

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a ``DiskCheckpointWriter`` for ``base_name``.

        ``base_name`` and ``overwrite`` are passed positionally; ``rank`` and
        ``world_size`` are forwarded as keyword arguments.
        """
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a ``DiskCheckpointReader`` for ``base_name``."""
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        """Create a fresh temporary directory inside ``base_name``.

        The directory is remembered in ``self.temps`` so that ``clean_temp``
        can delete it later.

        Args:
            base_name: Existing directory in which the temporary directory
                is created.

        Returns:
            The path of the newly created temporary directory.
        """
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        """Delete every temporary directory created through ``get_temp``.

        Each entry is dropped from ``self.temps`` as soon as its directory
        has been removed, so calling this method again is a no-op instead of
        failing on paths that no longer exist. If a removal raises, the
        failing entry and all not-yet-processed entries stay in
        ``self.temps`` and the exception propagates to the caller.
        """
        # Newest first: a temp dir created inside an earlier temp dir is
        # removed before its parent, so the child path still exists.
        while self.temps:
            shutil.rmtree(self.temps[-1])
            # Forget the entry only after the directory is actually gone.
            self.temps.pop()
|