ColossalAI/colossalai/inference/pipeline/policy/llama_ppinfer.py

from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    Policy,
    SubModuleReplacementDescription,
)
from colossalai.shardformer.policies.llama import LlamaPolicy
from ..modeling.llama import LlamaPipelineForwards


class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm; gather_output=True so every rank holds the full logits
            new_item = {
                LlamaForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
                        )
                    ]
                )
            }
            policy.update(new_item)
        if self.pipeline_stage_manager:
            # replace the default forward with the pipeline-parallel version
            self.set_pipeline_forward(
                model_cls=LlamaForCausalLM,
                new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
                policy=policy,
            )

        return policy
    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            # lm_head consumes the final hidden states, so it belongs to the last pipeline stage
            held_layers.append(self.model.lm_head)
        return held_layers
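
For context, a minimal usage sketch follows. It is not part of the file above; the ProcessGroupMesh, PipelineStageManager, ShardConfig, and ShardFormer.optimize calls are assumptions that should be checked against the installed ColossalAI release, and it presumes a distributed launch (e.g. via torchrun) with 2 ranks:

    from transformers import LlamaForCausalLM

    import colossalai
    from colossalai.cluster import ProcessGroupMesh
    from colossalai.pipeline.stage_manager import PipelineStageManager
    from colossalai.shardformer import ShardConfig, ShardFormer

    # assumed launch call; the config argument may differ across releases
    colossalai.launch_from_torch(config={})

    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    # a 2-stage pipeline with no tensor parallelism, so only the pipeline
    # branch of module_policy() is exercised
    pg_mesh = ProcessGroupMesh(2)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)

    shard_config = ShardConfig(
        pipeline_stage_manager=stage_manager,
        enable_tensor_parallelism=False,
    )
    shard_former = ShardFormer(shard_config=shard_config)

    # the policy swaps in llama_for_causal_lm_forward and keeps only this
    # stage's layers (embeddings on the first stage, lm_head on the last)
    sharded_model, _ = shard_former.optimize(model, policy=LlamaForCausalLMPipelinePolicy())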