mirror of https://github.com/hpcaitech/ColossalAI
[Bug Fix] import llama context ops fix (#4524)
* added _vllm_rms_norm
* change place
* added tests
* added tests
* modify
* adding kernels
* added tests
* adding kernels
* modify
* added
* updating kernels
* adding tests
* added tests
* kernel change
* submit
* modify
* added
* edit comments
* change name
* change comments and fix import
* add
* added
* fix
* add ops into init.py
* add
parent 2226c6836c
commit e937461312
colossalai/kernel/__init__.py
@@ -1,7 +1,14 @@
 from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
+from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
+from .triton import softmax
+from .triton import copy_kv_cache_to_dest
 
 __all__ = [
     "LayerNorm",
     "FusedScaleMaskSoftmax",
     "MultiHeadAttention",
+    "llama_context_attn_fwd",
+    "bloom_context_attn_fwd",
+    "softmax",
+    "copy_kv_cache_to_dest",
 ]
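With these ops re-exported from colossalai.kernel, downstream code can import them from the top-level kernel package instead of reaching into the Triton submodule. A minimal sketch of the shortened import path this enables (assuming ColossalAI is installed and Triton is available in the environment):

# Before this commit, callers had to import from the submodule, e.g.:
#   from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
# After it, the ops are exposed at the package level:
from colossalai.kernel import llama_context_attn_fwd, bloom_context_attn_fwd
from colossalai.kernel import softmax, copy_kv_cache_to_dest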
colossalai/kernel/triton/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
+from .softmax import softmax
+from .copy_kv_cache_dest import copy_kv_cache_to_dest
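This new package __init__ is what lets the tests below drop the .context_attention suffix from their import path: the long and the short form now resolve to the same function object. A small equivalence check (a sketch, assuming the package imports cleanly in your environment):

from colossalai.kernel.triton import bloom_context_attn_fwd as short_path
from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd as long_path

# The __init__ simply re-exports the submodule symbol, so both names are the same object.
assert short_path is long_path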
Bloom context-attention test:
@@ -9,8 +9,8 @@ from torch.nn import functional as F
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
-    from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import bloom_context_attn_fwd
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
LLaMA context-attention test:
@@ -9,8 +9,8 @@ from torch.nn import functional as F
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark, torch_context_attention
-    from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import llama_context_attn_fwd
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
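Both test hunks rely on the same import guard: the Triton import and the kernel re-export are attempted inside try/except, and HAS_TRITON records whether they succeeded so the test can be skipped on machines without Triton. A minimal sketch of that pattern (the test function and the pytest marker below are illustrative, not copied from the repository):

import pytest

try:
    import triton  # noqa: F401
    from colossalai.kernel.triton import llama_context_attn_fwd  # new package-level import path
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

@pytest.mark.skipif(not HAS_TRITON, reason="triton or the kernel package is not available")
def test_llama_context_attention():
    # Illustrative placeholder: the real test builds q/k/v tensors and compares
    # llama_context_attn_fwd against the torch_context_attention reference.
    assert HAS_TRITON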