mirror of https://github.com/hpcaitech/ColossalAI
23 lines
730 B
Python
23 lines
730 B
Python
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||
|
|
||
|
from torch import Tensor
|
||
|
from torch.nn import Module, Parameter
|
||
|
|
||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||
|
|
||
|
from .base import Policy
|
||
|
from .bert import BertModel, BertModelPolicy
|
||
|
|
||
|
POLICY_MAP: Dict[Type[Module], Type[Policy]] = {
|
||
|
BertModel: BertModelPolicy,
|
||
|
}
|
||
|
|
||
|
|
||
|
def pipeline_parallelize(
|
||
|
model: Module,
|
||
|
stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
|
||
|
if type(model) not in POLICY_MAP:
|
||
|
raise NotImplementedError(f"Policy for {type(model)} not implemented")
|
||
|
policy = POLICY_MAP[type(model)](stage_manager)
|
||
|
return policy.parallelize_model(model)
|