ColossalAI/colossalai/inference/pipeline/policy/llama_ppinfer.py


[Pipeline Inference] Sync pipeline inference branch to main (#4820)

* [pipeline inference] pipeline inference (#4492)
* add pp stage manager as circle stage
* fix a bug when create process group
* add ppinfer basic framework
* add micro batch manager and support kvcache-pp gpt2 fwd
* add generate schedule
* use mb size to control mb number
* support generate with kv cache
* add output, remove unused code
* add test
* reuse shardformer to build model
* refactor some code and use the same attribute name of hf
* fix review and add test for generation
* remove unused file
* fix CI
* add cache clear
* fix code error
* fix typo
* [Pipeline inference] Modify to tieweight (#4599)
* add pp stage manager as circle stage
* fix a bug when create process group
* add ppinfer basic framework
* add micro batch manager and support kvcache-pp gpt2 fwd
* add generate schedule
* use mb size to control mb number
* support generate with kv cache
* add output, remove unused code
* add test
* reuse shardformer to build model
* refactor some code and use the same attribute name of hf
* fix review and add test for generation
* remove unused file
* modify the way of saving newtokens
* modify to tieweight
* modify test
* remove unused file
* solve review
* add docstring
* [Pipeline inference] support llama pipeline inference (#4647)
* support llama pipeline inference
* remove tie weight operation
* [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708)
* add benchmark verbose
* fix export tokens
* fix benchmark verbose
* add P2POp style to do p2p communication
* modify schedule as p2p type when ppsize is 2
* remove unused code and add docstring
* [Pipeline inference] Refactor code, add docsting, fix bug (#4790)
* add benchmark script
* update argparse
* fix fp16 load
* refactor code style
* add docstring
* polish code
* fix test bug
* [Pipeline inference] Add pipeline inference docs (#4817)
* add readme doc
* add a ico
* Add performance
* update table of contents
* refactor code (#4873)
2023-10-11 03:40:06 +00:00
```python
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaPolicy

from ..modeling.llama import LlamaPipelineForwards


class LlamaForCausalLMPipelinePolicy(LlamaPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm: shard lm_head column-wise and
            # gather the output so every rank sees full-vocabulary logits
            new_item = {
                LlamaForCausalLM:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
                    ])
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(model_cls=LlamaForCausalLM,
                                      new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
                                      policy=policy)

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        # In this inference policy the first stage also holds lm_head
        if stage_manager.is_first_stage():
            held_layers.append(self.model.lm_head)
        return held_layers
```
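Under tensor parallelism the policy above swaps `lm_head` for `Linear1D_Col` with `gather_output=True`. The following is a minimal single-process sketch of the arithmetic a column-parallel linear layer is generally understood to perform, not ColossalAI's implementation: the helper names `matmul`, `split_columns` and `column_parallel_forward` are hypothetical, there are no process groups, and the "gather" is a plain list concatenation standing in for an all-gather across ranks.

```python
# Hypothetical single-process illustration of a column-parallel linear layer.
# Not ColossalAI's Linear1D_Col: no process groups or communication, just the math.
from typing import List, Union

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Compute x @ w.T, with w stored as (out_features, in_features) like nn.Linear."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w] for row in x]


def split_columns(w: Matrix, world_size: int) -> List[Matrix]:
    """Split the output dimension (rows of w) into world_size equal shards."""
    out_features = len(w)
    assert out_features % world_size == 0, "out_features must divide evenly"
    chunk = out_features // world_size
    return [w[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def column_parallel_forward(x: Matrix, w: Matrix, world_size: int,
                            gather_output: bool = True) -> Union[Matrix, List[Matrix]]:
    shards = split_columns(w, world_size)
    # Each simulated "rank" sees the full input but owns only a slice of the outputs.
    partial = [matmul(x, shard) for shard in shards]
    if not gather_output:
        return partial  # one partial output per rank
    # gather_output=True: concatenate the per-rank slices along the last dim,
    # which is what an all-gather would hand back to every rank.
    return [sum((p[i] for p in partial), []) for i in range(len(x))]
```

For example, with a weight of four output features split over two simulated ranks, each rank computes two output columns and the gathered result matches the unsharded `matmul(x, w)`. Gathering matters for `lm_head` because sampling the next token needs logits over the whole vocabulary, not one rank's slice.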
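`get_held_layers` starts from whatever the parent `LlamaPolicy` assigns to the current stage and then appends `lm_head` on the first stage. The sketch below lists, by module name, what each stage would hold under stated assumptions: the even split of decoder layers (remainder going to the earliest stages) and the placement of `embed_tokens` on the first stage and `norm` on the last are assumptions for illustration, not a statement of how `LlamaPolicy` actually distributes layers. `split_layers` and `held_modules` are hypothetical helpers.

```python
# Hypothetical sketch of per-stage module ownership for a Llama-style model.
# The layer split and the embed/norm placement are assumptions, not LlamaPolicy's code.
from typing import List


def split_layers(num_layers: int, num_stages: int) -> List[range]:
    """Evenly split decoder layers; any remainder goes to the earliest stages."""
    base, rem = divmod(num_layers, num_stages)
    ranges, start = [], 0
    for stage in range(num_stages):
        count = base + (1 if stage < rem else 0)
        ranges.append(range(start, start + count))
        start += count
    return ranges


def held_modules(stage: int, num_stages: int, num_layers: int) -> List[str]:
    held = []
    if stage == 0:
        held.append("embed_tokens")  # assumed: embeddings live on the first stage
    held.extend(f"layers.{i}" for i in split_layers(num_layers, num_stages)[stage])
    if stage == num_stages - 1:
        held.append("norm")  # assumed: final norm lives on the last stage
    if stage == 0:
        held.append("lm_head")  # mirrors get_held_layers: lm_head on the first stage
    return held
```

With two stages and four decoder layers, stage 0 would hold the embeddings, layers 0 and 1, and `lm_head`, while stage 1 would hold layers 2 and 3 and the final norm. Keeping `lm_head` on the first stage means the last stage's hidden states have to travel back to stage 0 before a token can be sampled, which appears consistent with the "pp stage manager as circle stage" entry in the commit history.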