ColossalAI/colossalai/booster/plugin.py

47 lines
947 B
Python
Raw Normal View History

from typing import List, Tuple
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from colossalai.device.device_mesh import DeviceMesh
__all__ = ["Plugin"]


class Plugin:
    """Base interface for a booster plugin.

    Every property and method below has a docstring-only body, so on this
    base class each of them evaluates to ``None``. The return annotations
    describe what an implementation is expected to produce instead.

    NOTE(review): these members look like hooks that concrete plugins are
    meant to override. They are not declared abstract, so ``Plugin()`` can
    be instantiated and a subclass that omits an override silently yields
    ``None``. Confirm no caller relies on that fallback before switching to
    ``abc.ABC`` / ``@abstractmethod``.
    """

    @property
    def supported_devices(self) -> List[torch.device]:
        """Annotated as a list of ``torch.device``; ``None`` on the base class."""

    @property
    def supported_precisions(self) -> List[str]:
        """Annotated as a list of precision names; ``None`` on the base class."""

    @property
    def control_precision(self) -> bool:
        """Flag, presumably "plugin manages precision itself" (name-based; confirm).

        Annotated as ``bool``; ``None`` on the base class.
        """

    @property
    def control_device(self) -> bool:
        """Flag, presumably "plugin manages device placement itself" (name-based; confirm).

        Annotated as ``bool``; ``None`` on the base class.
        """

    @property
    def support_no_sync(self) -> bool:
        """Flag, presumably "plugin supports a no-sync mode" (name-based; confirm).

        Annotated as ``bool``; ``None`` on the base class.
        """

    def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
        """Hook for preparing ``model`` for use with this plugin.

        Args:
            model: The module to set up.
            device_mesh_pool: Annotated as a single ``DeviceMesh`` despite the
                "pool" in its name. NOTE(review): confirm the intended type.

        Returns:
            Annotated as ``nn.Module``. The base class ignores both arguments
            and returns ``None``.
        """

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Hook for preparing ``optimizer`` for use with this plugin.

        Args:
            optimizer: The optimizer to set up.

        Returns:
            Annotated as ``Optimizer``. The base class ignores the argument
            and returns ``None``.
        """

    def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
        """Hook for preparing ``dataloader`` for use with this plugin.

        Args:
            dataloader: The data loader to set up.

        Returns:
            Annotated as ``DataLoader``. The base class ignores the argument
            and returns ``None``.
        """

    @property
    def device_mesh_shape(self) -> List[Tuple[int, ...]]:
        """Annotated as a list of integer tuples (mesh shapes); ``None`` on the base class."""