mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from colossalai.device.device_mesh import DeviceMesh

__all__ = ['Plugin']


class Plugin:
    """Base interface for plugins that customize distributed training setup.

    Subclasses implement the properties and ``setup_*`` hooks below to declare
    their capabilities and to wrap the model, optimizer, and dataloader.
    """

    @property
    def supported_devices(self) -> List[torch.device]:
        """Devices this plugin can run on."""
        pass

    @property
    def supported_precisions(self) -> List[str]:
        """Precision modes this plugin supports, e.g. 'fp16' or 'fp32'."""
        pass

    @property
    def control_precision(self) -> bool:
        """Whether the plugin takes over mixed-precision handling."""
        pass

    @property
    def control_device(self) -> bool:
        """Whether the plugin takes over device placement."""
        pass

    @property
    def support_no_sync(self) -> bool:
        """Whether the plugin offers a no-sync mode for gradient accumulation."""
        pass

    def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
        """Wrap or shard the model over the given device mesh."""
        pass

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Wrap the optimizer for distributed training."""
        pass

    def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
        """Wrap the dataloader, e.g. to use a distributed sampler."""
        pass

    @property
    def device_mesh_shape(self) -> List[Tuple[int, ...]]:
        """Shapes of the device meshes this plugin requests."""
        pass
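
For illustration, here is a minimal sketch of a subclass that satisfies this interface. The class name TrivialPlugin and all of its return values are hypothetical choices made for this example, not part of ColossalAI; a real plugin would wrap or shard the model, optimizer, and dataloader rather than return them unchanged.

# Hypothetical example plugin, for illustration only: it declares minimal
# capabilities and passes every object through untouched.
class TrivialPlugin(Plugin):

    @property
    def supported_devices(self) -> List[torch.device]:
        # Assumed for the example: CPU and CUDA.
        return [torch.device('cpu'), torch.device('cuda')]

    @property
    def supported_precisions(self) -> List[str]:
        return ['fp32']

    @property
    def control_precision(self) -> bool:
        # This plugin does not manage mixed precision itself.
        return False

    @property
    def control_device(self) -> bool:
        # Device placement is left to the caller.
        return False

    @property
    def support_no_sync(self) -> bool:
        return False

    def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
        # No wrapping or sharding: return the model as-is.
        return model

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        return optimizer

    def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
        return dataloader

    @property
    def device_mesh_shape(self) -> List[Tuple[int, ...]]:
        # A single 1-D mesh holding one device (an arbitrary example value).
        return [(1,)]

A real plugin would typically return True from the control_* properties it takes responsibility for, and have setup_model shard parameters across the mesh described by device_mesh_shape.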