from typing import Dict, List, Union

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager
from colossalai.inference.struct import DiffusionSequence, RequestStatus, Sequence
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

__all__ = ["RunningList", "RequestHandler"]


class RunningList:
    """
    RunningList is a structure for recording the running sequences, and it contains a prefill list and a decoding list.
    Prefilling samples are held back until the ratio of prefill samples to decoding samples exceeds ``prefill_ratio``.

    Args:
        prefill_ratio (float): A ratio for determining whether to perform prefill or not.
        _prefill (Dict[int, Sequence]): Mapping of sequence uid -> Sequence in the prefill stage.
        _decoding (Dict[int, Sequence]): Mapping of sequence uid -> Sequence in the decoding stage.
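
    Example:
        A rough usage sketch (illustrative only; ``seq`` stands for any ``Sequence`` with a unique
        ``request_id`` and is not defined in this module)::

            running = RunningList(prefill_ratio=1.2)
            running.append(seq)                            # new sequences enter the prefill list
            if running.ready_for_prefill():                # prefill work exists and the ratio gate passes
                running.mark_prefill_running()
                running.move_prefill_to_decoding([seq.request_id])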
    """

    def __init__(self, prefill_ratio: float, prefill: List[Sequence] = None) -> None:
        self.prefill_ratio = prefill_ratio
        self._decoding: Dict[int, Sequence] = dict()
        self._prefill: Dict[int, Sequence] = (
            {seq.request_id: seq for seq in prefill} if prefill is not None else dict()
        )

    @property
    def decoding(self):
        return list(self._decoding.values())

    @property
    def prefill(self):
        return list(self._prefill.values())

    @property
    def prefill_seq_num(self):
        return len(self._prefill)

    @property
    def decoding_seq_num(self):
        return len(self._decoding)

    @property
    def total_seq_num(self):
        return self.prefill_seq_num + self.decoding_seq_num

    def append(self, seq: Sequence):
        assert (seq.request_id not in self._prefill) and (
            seq.request_id not in self._decoding
        ), f"Sequence uid {seq.request_id} already exists."
        self._prefill[seq.request_id] = seq

    def extend(self, seqs: List[Sequence]):
        for seq in seqs:
            self._prefill[seq.request_id] = seq

    def find_seq(self, request_id) -> Union[Sequence, None]:
        seq = None
        if request_id in self._decoding:
            seq = self._decoding[request_id]
        elif request_id in self._prefill:
            seq = self._prefill[request_id]
        return seq

    def remove(self, seq: Sequence) -> None:
        if seq.request_id in self._decoding:
            self._decoding.pop(seq.request_id)
        elif seq.request_id in self._prefill:
            self._prefill.pop(seq.request_id)
        else:
            raise ValueError(f"Sequence {seq.request_id} is not in running list")

    def ready_for_prefill(self):
        if not self._decoding:
            return len(self._prefill) > 0
        return len(self._prefill) / len(self._decoding) >= self.prefill_ratio
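
    # Example of the prefill gate (illustrative, not part of the original module): with
    # prefill_ratio = 1.2, 3 prefill sequences and 2 decoding sequences give 3 / 2 = 1.5 >= 1.2,
    # so ready_for_prefill() returns True and the handler schedules a prefill batch next.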

    def is_empty(self):
        return not self._decoding and not self._prefill

    def mark_prefill_running(self) -> None:
        for seq_id in self._prefill:
            self._prefill[seq_id].mark_running()

    def move_prefill_to_decoding(self, seq_ids: List[int]) -> None:
        for seq_id in seq_ids:
            assert seq_id in self._prefill, f"Sequence {seq_id} is not in prefill list"
            self._decoding[seq_id] = self._prefill.pop(seq_id)


class NaiveRequestHandler:
    def __init__(self) -> None:
        self.running_list: List[DiffusionSequence] = []
        self.waiting_list: List[DiffusionSequence] = []

    def _has_waiting(self) -> bool:
        return any(lst for lst in self.waiting_list)

    def _has_running(self) -> bool:
        return any(lst for lst in self.running_list)

    def check_unfinished_reqs(self):
        return self._has_waiting() or self._has_running()

    def add_sequence(self, seq: DiffusionSequence):
        """
        Add the request to the waiting list.
        """
        assert not self._find_sequence(seq.request_id), f"Sequence {seq.request_id} already exists."
        self.waiting_list.append(seq)

    def _find_sequence(self, request_id: int) -> DiffusionSequence:
        """
        Find the request by request_id.
        """
        for seq in self.waiting_list + self.running_list:
            if seq.request_id == request_id:
                return seq
        return None

    def schedule(self):
        ret = None
        if self._has_waiting():
            # Pop the oldest waiting request (FIFO).
            ret = self.waiting_list[0]
            self.waiting_list = self.waiting_list[1:]
        return ret
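
    # Rough driver sketch (illustrative only, not part of the original module). The inference
    # engine is assumed to queue DiffusionSequence requests and drain them FIFO, e.g.:
    #
    #     handler = NaiveRequestHandler()
    #     handler.add_sequence(diffusion_seq)      # `diffusion_seq` is a hypothetical request
    #     while handler.check_unfinished_reqs():
    #         seq = handler.schedule()             # oldest waiting request, or None
    #         ...                                  # run the diffusion pipeline on `seq`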


class RequestHandler(NaiveRequestHandler):
    """
    RequestHandler is the core component for managing existing requests and updating the current batch.
    During generation, the schedule method is called at each iteration to update the current batch.

    Args:
        inference_config: Configuration for initializing and managing the KV cache.
        model_config: Configuration of the model.
        dtype (torch.dtype): The data type for weights and activations.
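
    Example:
        A rough sketch of how an engine might drive the handler (illustrative only; the config
        objects and ``seq`` are assumed to be constructed elsewhere)::

            handler = RequestHandler(inference_config, model_config)
            handler.add_sequence(seq)              # queue a new request
            while handler.check_unfinished_reqs():
                batch = handler.schedule()         # a prefill or a decoding BatchBucket
                ...                                # run one model forward step on `batch`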
    """

    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
        self.inference_config = inference_config
        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
        self.waiting_list: List[List] = [[], [], []]
        self.done_list: List[Sequence] = []
        self.dtype = inference_config.dtype
        self.max_batch_size = inference_config.max_batch_size

        # initialize cache
        self._init_cache(model_config)

        # initialize batch
        device = torch.cuda.current_device()
        # Maximum number of cache blocks a single sequence can span: ceil((max_input_len + max_output_len) / block_size).
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
        head_dim = model_config.hidden_size // model_config.num_attention_heads

        fd_inter_tensor = FDIntermTensors()

        if fd_inter_tensor._tensors_initialized:
            fd_inter_tensor._reset()

        # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq
        max_n_tokens = self.max_batch_size
        max_n_tokens *= self.inference_config.max_n_spec_tokens + 1

        fd_inter_tensor.initialize(
            max_batch_size=max_n_tokens,
            num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
            kv_max_split_num=kv_max_split_num,
            head_dim=head_dim,
            dtype=self.dtype,
            device=device,
        )

        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        self.running_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=fd_inter_tensor,
            dtype=self.dtype,
            device=device,
            enable_streamingllm=inference_config.enable_streamingllm,
            start_token_size=inference_config.start_token_size,
            generated_token_size=inference_config.generated_token_size,
        )
        self.prefill_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=fd_inter_tensor,
            dtype=self.dtype,
            device=device,
            enable_streamingllm=inference_config.enable_streamingllm,
            start_token_size=inference_config.start_token_size,
            generated_token_size=inference_config.generated_token_size,
        )
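
        # Note (added for clarity): `prefill_bb` and `running_bb` are configured identically;
        # schedule() fills `prefill_bb` with sequences awaiting their first forward pass and
        # keeps `running_bb` for sequences that are already decoding.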

    def _has_running(self) -> bool:
        return not self.running_bb.is_empty()

    def _init_cache(self, model_config):
        self.cache_manager = KVCacheManager(self.inference_config, model_config)

    def get_kvcache(self):
        return self.cache_manager.get_kv_cache()

    def set_spec_dec_mode(self, n_spec_tokens: int):
        self.prefill_bb.set_use_spec_dec(n_spec_tokens)
        self.running_bb.set_use_spec_dec(n_spec_tokens)

    def unset_spec_dec_mode(self):
        self.prefill_bb.reset_use_spec_dec()
        self.running_bb.reset_use_spec_dec()

    def schedule(self):
        """
        The main logic of the request handler: build and return the next batch (prefill or decoding) to run.
        """
        if self._has_waiting():
            # Try to allocate cache blocks for the sequences, prioritized by prompt length.
            for lst in reversed(self.waiting_list):
                if lst:
                    remove_list = []
                    for seq in lst:
                        if seq.input_len > self.inference_config.max_input_len:
                            # If the prompt length is longer than max_input_len, abort the sequence.
                            logger.warning(
                                f"The prompt (Request id = {seq.request_id}) is longer than max_input_len; aborting this sequence."
                            )
                            self.abort_sequence(seq.request_id)
                            remove_list.append(seq)
                            break
                    num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
                    # for now the recycle logic is not working
                    remove_list.extend(lst[:num_seqs_to_add])
                    self.running_list.extend(lst[:num_seqs_to_add])

                    for seq in remove_list:
                        lst.remove(seq)

        if self.running_list.ready_for_prefill():
            num_seqs_to_add = min(self.running_list.prefill_seq_num, self.prefill_bb.available_batch_size)
            # overwrite the number of sequences to add to 1 if use_spec_dec is enabled
            # TODO (zhaoyuanheng): support speculative decoding for batch size > 1
            if self.prefill_bb.use_spec_dec:
                num_seqs_to_add = 1

            for seq in self.running_list.prefill[:num_seqs_to_add]:
                seq.mark_running()
            # allocate blocks for the prefill batch
            self.prefill_bb.add_seqs(
                self.running_list.prefill[:num_seqs_to_add],
                alloc_block_tables_fn=self.cache_manager.allocate_context_from_block_tables,
            )

            return self.prefill_bb

        if not self.running_bb.is_empty:
            seqs_ids_to_recycle = self.cache_manager.allocate_tokens_from_block_tables(
                self.running_bb.block_tables, self.running_bb.seq_lengths, self.running_bb.current_batch_size
            )
            if seqs_ids_to_recycle:
                seqs_to_recycle = self.running_bb.pop_seqs(seqs_ids_to_recycle)
                for seq in seqs_to_recycle:
                    seq.recycle()
                    self.running_list.remove(seq)
                    # the recycled sequences are re-queued with the highest priority.
                    self.waiting_list[-1].append(seq)

        return self.running_bb

    def allocate_batch_spec_dec(self, batch: BatchBucket, n: int):
        assert batch.use_spec_dec
        if n > 0:
            self.cache_manager.allocate_n_tokens_from_block_tables(
                batch.block_tables, batch.seq_lengths, batch.current_batch_size, n=n
            )

    def add_sequence(self, req: Sequence):
        """
        Add the request to the waiting list.
        """
        assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists."
        assert (
            req.input_len <= self.inference_config.max_input_len
        ), f"Sequence {req.request_id} exceeds input length limit"
        # Bucket the request by prompt length into one of the three waiting sub-lists.
        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
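        # For example (illustrative numbers): with max_input_len = 512, a 100-token prompt maps to
        # bucket 100 * 3 // 513 = 0, a 300-token prompt to 300 * 3 // 513 = 1, and a 500-token prompt
        # to 500 * 3 // 513 = 2, so longer prompts land in higher-index buckets, which schedule() visits first.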

    def abort_sequence(self, request_id: int):
        """
        Abort the request.
        """
        result = self._find_sequence(request_id)
        if result is not None:
            seq, priority = result
            if seq.status == RequestStatus.WAITING:
                seq.mark_aborted()
                self.waiting_list[priority].remove(seq)
            elif seq.status.is_running():
                self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
                self.running_list.remove(seq)
            else:
                try:
                    self.done_list.remove(seq)
                except ValueError:
                    # The sequence is no longer in the done list; nothing to remove.
                    return
        return

    def _find_sequence(self, request_id: int) -> Sequence:
        """
        Find the request by request_id.
        """
        for priority, lst in enumerate(self.waiting_list):
            for seq in lst:
                if seq.request_id == request_id:
                    return seq, priority

        seq = self.running_list.find_seq(request_id)
        if seq is not None:
            return seq, None

        return None

    def update_seq_finished(self, sequence: Sequence, generation_config: GenerationConfig):
        if (
            sequence.output_token_id[-1] == generation_config.eos_token_id
            or sequence.output_len >= generation_config.max_length
        ):
            sequence.mark_finished()

    def update_batch_finished(self, batch: BatchBucket, generation_config: GenerationConfig):
        for seq in batch.seqs_li:
            max_length = generation_config.max_length
            max_new_tokens = generation_config.max_new_tokens
            if max_length is not None:
                # max_length counts the prompt, so convert it into a budget of newly generated tokens.
                max_new_tokens = max_length - seq.input_len
            if seq.output_token_id[-1] == generation_config.eos_token_id or seq.output_len >= max_new_tokens:
                seq.mark_finished()
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
|
|
|
    def check_unfinished_reqs(self) -> bool:
        """Whether there are still unfinished requests, i.e. waiting or running sequences."""
        return self._has_waiting() or not self.running_list.is_empty()

    def total_requests_in_batch_bucket(self) -> int:
        """Total number of sequences currently held in the prefill and running batch buckets."""
        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

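    # Note (assumption, not verified in this module): `total_requests_in_batch_bucket` is
    # presumably what the engine checks against its batch capacity before admitting new
    # prefill requests.
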
    def append_next_tokens(self, sample_tokens: torch.Tensor):
        """Append the newly sampled tokens (one per sequence) to the active batch bucket."""
        assert sample_tokens.dim() == 1
        n_elements = sample_tokens.size(0)
        if not self.prefill_bb.is_empty:
            assert (
                self.prefill_bb.current_batch_size == n_elements
            ), f"Incompatible size: {n_elements} tokens to append while prefill batch size {self.prefill_bb.current_batch_size}"
            self.prefill_bb.append_batch_tokens(sample_tokens)
        else:
            assert (
                self.running_bb.current_batch_size == n_elements
            ), f"Incompatible size: {n_elements} tokens to append while running batch size {self.running_bb.current_batch_size}"
            self.running_bb.append_batch_tokens(sample_tokens)

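    # Sketch of the expected input (an assumption about the surrounding engine loop; the
    # sampling helper named below is hypothetical and not defined in this module):
    #   sample_tokens = sample_from_logits(logits)   # 1-D tensor, one token id per sequence,
    #   handler.append_next_tokens(sample_tokens)    # in the same order as the batch bucket
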
    def update(self):
        """
        Update the current running list and done list, and return the sequences that finished in this step.
        """
        if not self.prefill_bb.is_empty:
            self.running_list.move_prefill_to_decoding(self.prefill_bb.seqs_ids)
            self.running_bb.merge(self.prefill_bb)
            # Clear the prefill batch without assigning a free_block_tables_fn,
            # since we want to reuse the memory recorded on the block tables.
            self.prefill_bb.clear(free_block_tables_fn=None)

        finished_seqs, _ = self.running_bb.pop_finished(self.cache_manager.free_block_table)
        for seq in finished_seqs:
            self.running_list.remove(seq)
        self.done_list.extend(finished_seqs)

        return finished_seqs

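    # A rough per-step call order, sketched under assumptions about the engine loop
    # (only append_next_tokens and update are defined here; the other names are hypothetical):
    #   batch = handler.schedule()                 # pick the prefill or decoding batch
    #   logits = model_forward(batch)              # hypothetical model call
    #   handler.append_next_tokens(sampled_ids)    # one sampled token id per sequence
    #   finished = handler.update()                # merge prefill into running, pop finished
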
    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
        """
        Free the blocks that need to be swapped out.
        """
        self.cache_manager.streamingllm_free_block_tables(updated_block_ids)

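    # Assumption: this is the hook used by the StreamingLLM-style cache policy, where KV-cache
    # blocks outside the retained window are evicted and their slots returned to the free pool.
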

class RPCRequestHandler(RequestHandler):
    """
    RPC version of the request handler, which delegates KV-cache bookkeeping to an RPCKVCacheManager.
    """

    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
        self.inference_config = inference_config
        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
        self.waiting_list: List[List] = [[], [], []]
        self.done_list: List[Sequence] = []
        self.dtype = inference_config.dtype
        self.max_batch_size = inference_config.max_batch_size

        # initialize cache
        self._init_cache(model_config)

        # initialize batch
        torch.cuda.current_device()
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
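        # kv_max_split_num is ceil((max_input_len + max_output_len) / block_size), i.e. the most
        # KV-cache blocks a single sequence can span. Worked example (assumed config values):
        # max_input_len=1024, max_output_len=256, block_size=16 -> (1024 + 256 + 16 - 1) // 16 = 80.
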
        head_dim = model_config.hidden_size // model_config.num_attention_heads

        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        # Batch buckets for decoding-phase (running) and prefill-phase sequences.
        self.running_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=None,
            dtype=self.dtype,
        )
        self.prefill_bb = BatchBucket(
            num_heads=model_config.num_attention_heads // inference_config.tp_size,
            head_dim=head_dim,
            max_batch_size=self.max_batch_size,
            max_length=inference_config.max_input_len + inference_config.max_output_len,
            block_size=inference_config.block_size,
            kv_max_split_num=kv_max_split_num,
            fd_interm_tensor=None,
            dtype=self.dtype,
        )

    def _init_cache(self, model_config):
        self.cache_manager = RPCKVCacheManager(self.inference_config, model_config)
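
    # A minimal construction sketch (the config objects and values here are assumptions for
    # illustration, not taken from this module):
    #   model_config = PretrainedConfig(hidden_size=4096, num_attention_heads=32)
    #   handler = RPCRequestHandler(InferenceConfig(max_batch_size=8, block_size=16), model_config)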