# ColossalAI/colossalai/shardformer/policies/auto_policy.py
# (web file-viewer header residue: 269 lines, 12 KiB, Python)

import importlib
from dataclasses import dataclass
import torch.nn as nn
# [blame-view residue] 2023-07-05 07:13:00 +00:00
from .base_policy import Policy
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@dataclass
class PolicyLocation:
    """
    PolicyLocation describes where a policy class can be found.

    It only stores two strings; the policy module itself is not imported
    when a PolicyLocation is created.

    Args:
        file_name (str): The file name of the policy under colossalai.shardformer.policies
        class_name (str): The class name of the policy class
    """

    file_name: str
    class_name: str
# we don't want to import all policies here
# as each policy file imports its own model zoo library
# we will allow the user to only import the policy file needed
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation(
file_name="bert", class_name="BertForPreTrainingPolicy"
),
"transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(
file_name="bert", class_name="BertLMHeadModelPolicy"
),
"transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation(
file_name="bert", class_name="BertForMaskedLMPolicy"
),
"transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation(
file_name="bert", class_name="BertForSequenceClassificationPolicy"
),
"transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation(
file_name="bert", class_name="BertForTokenClassificationPolicy"
),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation(
file_name="bert", class_name="BertForNextSentencePredictionPolicy"
),
"transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(
file_name="bert", class_name="BertForMultipleChoicePolicy"
),
"transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation(
file_name="bert", class_name="BertForQuestionAnsweringPolicy"
),
# LLaMA
"transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
file_name="llama", class_name="LlamaModelPolicy"
),
"transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
file_name="llama", class_name="LlamaForCausalLMPolicy"
),
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation(
file_name="llama", class_name="LlamaForSequenceClassificationPolicy"
),
# T5
"transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation(
file_name="t5", class_name="T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
"transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation(
file_name="t5", class_name="T5ForTokenClassificationPolicy"
),
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(
file_name="gpt2", class_name="GPT2LMHeadModelPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation(
file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation(
file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation(
file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"
),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(
file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"
),
# GPTJ
"transformers.models.gptj.modeling_gptj.GPTJModel": PolicyLocation(file_name="gptj", class_name="GPTJModelPolicy"),
"transformers.models.gptj.modeling_gptj.GPTJForCausalLM": PolicyLocation(
file_name="gptj", class_name="GPTJForCausalLMPolicy"
),
"transformers.models.gptj.modeling_gptj.GPTJForQuestionAnswering": PolicyLocation(
file_name="gptj", class_name="GPTJForQuestionAnsweringPolicy"
),
"transformers.models.gptj.modeling_gptj.GPTJForSequenceClassification": PolicyLocation(
file_name="gptj", class_name="GPTJForSequenceClassificationPolicy"
),
# ViT
"transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
"transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation(
file_name="vit", class_name="ViTForImageClassificationPolicy"
),
"transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation(
file_name="vit", class_name="ViTForMaskedImageModelingPolicy"
),
# OPT
"transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
"transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation(
file_name="opt", class_name="OPTForCausalLMPolicy"
),
"transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation(
file_name="opt", class_name="OPTForSequenceClassificationPolicy"
),
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(
file_name="opt", class_name="OPTForQuestionAnsweringPolicy"
),
# Bloom
"transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
file_name="bloom", class_name="BloomModelPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
file_name="bloom", class_name="BloomForCausalLMPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation(
file_name="bloom", class_name="BloomForSequenceClassificationPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation(
file_name="bloom", class_name="BloomForTokenClassificationPolicy"
),
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation(
file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"
),
# Whisper
"transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation(
file_name="whisper", class_name="WhisperModelPolicy"
),
"transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation(
file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"
),
"transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation(
file_name="whisper", class_name="WhisperForAudioClassificationPolicy"
),
# Sam
"transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
# Blip2
"transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation(
file_name="blip2", class_name="Blip2ModelPolicy"
),
"transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation(
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
),
# ChatGLM
"transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMModelPolicy"
),
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Deepseek
"transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
file_name="deepseek", class_name="DeepseekModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForCausalLM": PolicyLocation(
file_name="falcon", class_name="FalconForCausalLMPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForSequenceClassification": PolicyLocation(
file_name="falcon", class_name="FalconForSequenceClassificationPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForTokenClassification": PolicyLocation(
file_name="falcon", class_name="FalconForTokenClassificationPolicy"
),
"transformers.models.falcon.modeling_falcon.FalconForQuestionAnswering": PolicyLocation(
file_name="falcon", class_name="FalconForQuestionAnsweringPolicy"
),
# [blame-view residue] commit "[MoE/ZeRO] Moe refactor with zero refactor (#5821)", 2024-06-28 06:00:08 +00:00
# mistral
"transformers.models.mistral.modeling_mistral.MistralModel": PolicyLocation(
file_name="mistral", class_name="MistralModelPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForCausalLM": PolicyLocation(
file_name="mistral", class_name="MistralForCausalLMPolicy"
),
"transformers.models.mistral.modeling_mistral.MistralForSequenceClassification": PolicyLocation(
file_name="mistral", class_name="MistralForSequenceClassificationPolicy"
),
# [blame-view residue] commit "[MoE/ZeRO] Moe refactor with zero refactor (#5821)", 2024-06-28 06:00:08 +00:00
# mixtral
"transformers.models.mixtral.modeling_mixtral.MixtralModel": PolicyLocation(
file_name="mixtral", class_name="MixtralModelPolicy"
),
"transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
file_name="mixtral", class_name="MixtralForCausalLMPolicy"
),
# Qwen2
"transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
file_name="qwen2", class_name="Qwen2ModelPolicy"
),
"transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForCausalLMPolicy"
),
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
),
[MoE/ZeRO] Moe refactor with zero refactor (#5821) * [moe] removed openmoe-coupled code and rectify mixstral code (#5471) * [Feauture] MoE refractor; Intergration with Mixtral (#5682) * cherry pick from refractor-moe branch * tests passed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support ep + zero --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add mixtral auto policy & move pipeline forward code to modeling folder * [moe refactor] modify kernel test without Route Class * [moe refactor] add moe tensor test path environment variable to github workflow * fix typos * fix moe test bug due to the code rebase * [moe refactor] fix moe zero test, and little bug in low level zero * fix typo * add moe tensor path to github workflow * remove some useless code * fix typo & unify global variable XX_AXIS logic without using -1 * fix typo & prettifier the code * remove print code & support zero 2 test * remove useless code * reanme function * fix typo * fix typo * Further improve the test code * remove print code * [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test * [moe refactor] skip some unit test which will be refactored later * [moe refactor] fix unit import error * [moe refactor] fix circular import issues * [moe refactor] remove debug code * [moe refactor] update github workflow * [moe/zero] refactor low level optimizer (#5767) * [zero] refactor low level optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] MoE refactor with newest version of ZeRO (#5801) * [zero] remove redundant members in BucketStore (#5802) * [zero] align api with previous version * [Moe/Zero] Update MoeHybridParallelPlugin with 
refactored ZeRO and Fix Zero bug (#5819) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * [hotfix]Solve the compatibility issue of zero refactor (#5823) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * Modify function parameter names to resolve compatibility issues * [zero] fix missing hook removal (#5824) * [MoE] Resolve .github conflict (#5829) * [Fix/Example] Fix Llama Inference Loading Data Type (#5763) * [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3 * [release] update version (#5752) * [release] update version * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [test] fix ddp plugin test * [test] fix gptj and rpc test * [devops] fix cuda ext compatibility * [inference] fix flash decoding test * [inference] fix flash decoding test * fix (#5765) * [test] Fix/fix testcase (#5770) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [Hotfix] Add missing init file in inference.executor (#5774) * [CI/tests] simplify some test case to reduce testing time (#5755) * [ci/tests] simplify some test case to reduce testing time * [ci/tests] continue to remove test case to reduce ci time cost * restore some test config * [ci/tests] continue 
to reduce ci time cost * [misc] update dockerfile (#5776) * [misc] update dockerfile * [misc] update dockerfile * [devops] fix docker ci (#5780) * [Inference]Add Streaming LLM (#5745) * Add Streaming LLM * add some parameters to llama_generation.py * verify streamingllm config * add test_streamingllm.py * modified according to the opinions of review * add Citation * change _block_tables tolist * [hotfix] fix llama flash attention forward (#5777) * [misc] Accelerate CI for zero and dist optim (#5758) * remove fp16 from lamb * remove d2h copy in checking states --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Test/CI] remove test cases to reduce CI duration (#5753) * [test] smaller gpt2 test case * [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py * [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py * [test] reduce test cases tests/test_zero/test_gemini/test_optim.py * Revert "[test] smaller gpt2 test case" Some tests might depend on the size of model (num of chunks) This reverts commit df705a5210b8901645992adf276e320e48766ebf. * [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py * [CI] smaller test model for two mwo the two modifid cases * [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there * [hotfix] fix testcase in test_fx/test_tracer (#5779) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [fix] fix test_deepfm_model & test_dlrf_model; * [fix] fix test_hf_albert & test_hf_gpt; * [gemini] optimize reduce scatter d2h copy (#5760) * [gemini] optimize reduce scatter d2h copy * [fix] fix missing reduce variable * [refactor] remove legacy async reduce scatter code * [gemini] missing sync * Revert "[refactor] remove legacy async reduce scatter code" This reverts commit 58ad76d4665032bbe548d066116d1c572ce98979. 
* [gemini] further optimize with async all reduce * [fix] pass flag from manager to chunk * Allow building cuda extension without a device. (#5535) Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are. * [misc] fix dist logger (#5782) * [install]fix setup (#5786) * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update requirements (#5787) * [shardformer] fix import (#5788) * upgrade colossal-chat support tp_group>1, add sp for sft * upgrade ppo dpo rm script * run pre-commit * moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy * fix training script * fix ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix transformers version * remove duplicated test * fix datasets version * remove models that require huggingface auth from ci * remove local data path * update ci * remove baichuan from template test due to transformer version conflict * merge * Refactor modeling by adding attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix tests and naming Signed-off-by: char-1ee <xingjianli59@gmail.com> * Pass inference model shard configs for module init Signed-off-by: char-1ee <xingjianli59@gmail.com> * Clean up Signed-off-by: char-1ee <xingjianli59@gmail.com> * replace the customized dataloader setup with the build-in one * replace the customized dataloader setup with the build-in one * Remove flash attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * fix readme * Fix test import Signed-off-by: char-1ee <xingjianli59@gmail.com> * update sft trainning script * [Inference]refactor baichuan (#5791) * refactor baichuan * remove unused code and add TODO for lazyinit * [test] fix chatglm test kit (#5793) * 
[shardformer] fix modeling of bloom and falcon (#5796) * [test] fix qwen2 pytest distLarge (#5797) * [Inference] Fix flash-attn import and add model test (#5794) * Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> * [Gemini] Use async stream to prefetch and h2d data moving (#5781) * use async stream to prefetch and h2d data moving * Remove redundant code * [gemini] quick fix on possible async operation (#5803) * [gemini] quick fix on possible async operation * [gemini] quick fix on possible async operation * [shardformer] upgrade transformers to 4.39.3 (#5815) * [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 <lhx0217@gmail.com> * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix * [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 <lhx0217@gmail.com> * Support 4d parallel + flash attention (#5789) * support tp + sp + pp * remove comments --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> 
Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> * [zero] fix hook bug * [zero] add low level optimizer back (#5839) * [zero] fix param & refactor * [zero] add back original low level opt * [zero] remove moe related * [zero] pass zero tests * [zero] refactor * [chore] add del func back * [zero] comments and naming (#5840) * [zero] modify api (#5843) * [zero] modify api * [test] remove _grad_store access in tests * [test] fix (#5857) * [CI] skip openmoe CI check * [CI] fox pre-commit * [zero] remove redundant memebr init (#5862) * [misc] remove useless code, modify the pg mesh implementation * [misc] remove useless code, modify the pg mesh implementation * [misc] use tempfile * resolve conflict with main branch * [misc] use tempfile in test_moe_checkpoint.py * [misc] remove useless code, add assertion about sequence parallel, move logger into function * [misc] remove useless code --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> 
Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
2024-06-28 06:00:08 +00:00
# command
2024-06-14 03:04:56 +00:00
"transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
file_name="command", class_name="CommandModelPolicy"
),
"transformers.models.cohere.modeling_cohere.CohereForCausalLM": PolicyLocation(
file_name="command", class_name="CommandForCausalLMPolicy"
),
}
def import_policy(policy_location: PolicyLocation) -> Policy:
    """
    Load the policy described by ``policy_location`` on demand.

    The policy file is imported lazily from ``colossalai.shardformer.policies``
    so that only the requested policy module gets imported.

    Args:
        policy_location (PolicyLocation): Supplies ``file_name`` (the module under
            ``colossalai.shardformer.policies``) and ``class_name`` (the attribute
            to fetch from that module).

    Return:
        The attribute named ``class_name`` looked up on the imported module. It is
        returned as-is; nothing is instantiated here.
    """
    policy_module = importlib.import_module(f"colossalai.shardformer.policies.{policy_location.file_name}")
    policy_attr = getattr(policy_module, policy_location.class_name)
    return policy_attr
def _fullname(obj):
    """
    Build the dotted ``<module>.<qualified class name>`` string for ``obj``'s class.

    Classes whose module is ``builtins`` are reported by qualified name only.
    Module paths starting with ``transformers_modules`` are collapsed to their
    first and last dotted components.
    """
    cls = obj.__class__
    module_path = cls.__module__
    class_name = cls.__qualname__

    # avoid outputs like 'builtins.str'
    if module_path == "builtins":
        return class_name

    # Custom models that do not live inside transformers get a normalized module path.
    # The raw path may look like
    # 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm'
    # (from huggingface hub) or 'transformers_modules.chatglm.modeling_chatglm' (from local directory).
    if module_path.startswith("transformers_modules"):
        parts = module_path.split(".")
        if len(parts) >= 2:
            # keep only the leading package and the trailing modeling file
            module_path = ".".join((parts[0], parts[-1]))

    return f"{module_path}.{class_name}"
def get_autopolicy(model: nn.Module) -> Policy:
    r"""
    Return the auto policy for the model.

    The model's class is turned into a dotted name with ``_fullname`` and looked up
    in ``_POLICY_LIST``. On a match, the policy is imported through
    ``import_policy`` and instantiated with no arguments.

    Args:
        model (:class:`nn.Module`): The model to get the auto policy

    Return:
        :class:`Policy`: The auto policy for the model

    Raises:
        NotImplementedError: If ``_POLICY_LIST`` has no entry for the model's class.
    """
    full_name = _fullname(model)
    policy_location = _POLICY_LIST.get(full_name)

    if policy_location is None:
        supported_models = list(_POLICY_LIST.keys())
        # The sentence ends before the line break so the second line does not start with a stray '.'.
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented.\n"
            f"Supported models are {supported_models}"
        )

    # import_policy hands back the policy class; callers of this function get an instance.
    policy_cls = import_policy(policy_location)
    return policy_cls()