mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Add fused rotary kernel and get cos cache kernel (#5302)
* add fused rotary and get cos cache func
* staged
* fix bugs
* fix bugs

pull/5311/head
parent 3da9993b0d
commit c647e00e3c
colossalai/kernel/triton/__init__.py:
@@ -11,11 +11,12 @@ if HAS_TRITON:
    from .context_attn_unpad import context_attention_unpadded
    from .flash_decoding import flash_decoding_attention
    from .flash_decoding_utils import FDIntermTensors
    from .rms_layernorm import rms_layernorm
    from .fused_rotary_embedding import fused_rotary_embedding
    from .gptq_triton import gptq_fused_linear_triton
    from .kvcache_copy import copy_kv_to_blocked_cache
    from .no_pad_rotary_embedding import rotary_embedding
    from .rms_layernorm import rms_layernorm
    from .rotary_cache_copy import get_xine_cache
    from .softmax import softmax

    __all__ = [
@@ -27,4 +28,6 @@ if HAS_TRITON:
        "gptq_fused_linear_triton",
        "rotary_embedding",
        "FDIntermTensors",
        "fused_rotary_embedding",
        "get_xine_cache",
    ]
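Both new kernels are re-exported from the package root, so downstream code can import them directly. A minimal sketch (the import path matches the test file at the end of this diff):

from colossalai.kernel.triton import fused_rotary_embedding, get_xine_cache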
colossalai/kernel/triton/fused_rotary_embedding.py (new file):
@@ -0,0 +1,182 @@
import torch
import triton
import triton.language as tl


@triton.jit
def fused_rotary_emb(
    q,
    k,
    cos_cache,
    sin_cache,
    cumsum_lengths,
    q_token_stride,
    q_head_stride,
    k_token_stride,
    k_head_stride,
    head_dim_stride,
    cos_token_stride,
    cos_dim_stride,
    q_total_tokens,
    Q_HEAD_NUM: tl.constexpr,
    K_HEAD_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    N_ELEMENTS: tl.constexpr,
):
    block_head_index = tl.program_id(0)
    block_group_index = tl.program_id(1)
    group_token_index = tl.program_id(2)
    idx = block_group_index * BLOCK_SIZE + group_token_index

    # position of this token within its original sequence
    cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
    ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
    cos = tl.load(
        cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride
    )  # [1, HEAD_DIM // 2]
    sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride)

    cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    off_q0 = (
        idx * q_token_stride
        + cur_head_range[None, :, None] * q_head_stride
        + dim_range0[None, None, :] * head_dim_stride
    )
    off_q1 = (
        idx * q_token_stride
        + cur_head_range[None, :, None] * q_head_stride
        + dim_range1[None, None, :] * head_dim_stride
    )

    off_k0 = (
        idx * k_token_stride
        + cur_head_range[None, :, None] * k_head_stride
        + dim_range0[None, None, :] * head_dim_stride
    )
    off_k1 = (
        idx * k_token_stride
        + cur_head_range[None, :, None] * k_head_stride
        + dim_range1[None, None, :] * head_dim_stride
    )

    q_0 = tl.load(
        q + off_q0,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    q_1 = tl.load(
        q + off_q1,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    k_0 = tl.load(
        k + off_k0,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    k_1 = tl.load(
        k + off_k1,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
        other=0.0,
    )

    # rotate the two halves: out0 = x0 * cos - x1 * sin, out1 = x0 * sin + x1 * cos
    out_q0 = q_0 * cos - q_1 * sin
    out_q1 = q_0 * sin + q_1 * cos

    out_k0 = k_0 * cos - k_1 * sin
    out_k1 = k_0 * sin + k_1 * cos
    # concat
    tl.store(
        q + off_q0,
        out_q0,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
    )
    tl.store(
        q + off_q1,
        out_q1,
        mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
    )

    tl.store(
        k + off_k0,
        out_k0,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
    )
    tl.store(
        k + off_k1,
        out_k1,
        mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
    )


@torch.no_grad()
def fused_rotary_embedding(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    lengths,
):
    """
    Args:
        q: query tensor, [total_tokens, head_num, head_dim]
        k: key tensor, [total_tokens, head_num, head_dim]
        cos: cosine for rotary embedding, [max_position_len, head_dim]
        sin: sine for rotary embedding, [max_position_len, head_dim]
        lengths: [num_seqs]
    """
    q_total_tokens, q_head_num, head_dim = q.shape
    assert q.size(0) == k.size(0)
    BLOCK_HEAD = 4
    BLOCK_SIZE = 16
    cumsum_lens = torch.cumsum(lengths, dim=0)

    grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)

    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    q_token_stride = q.stride(0)
    q_head_stride = q.stride(1)
    head_dim_stride = q.stride(2)

    k_token_stride = k.stride(0)
    k_head_stride = k.stride(1)

    k_head_num = k.shape[1]

    cos_token_stride = cos.stride(0)
    cos_dim_stride = cos.stride(1)

    fused_rotary_emb[grid](
        q,
        k,
        cos,
        sin,
        cumsum_lens,
        q_token_stride,
        q_head_stride,
        k_token_stride,
        k_head_stride,
        head_dim_stride,
        cos_token_stride,
        cos_dim_stride,
        q_total_tokens,
        Q_HEAD_NUM=q_head_num,
        K_HEAD_NUM=k_head_num,
        HEAD_DIM=head_dim,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SIZE=BLOCK_SIZE,
        N_ELEMENTS=triton.next_power_of_2(q_total_tokens),
        num_warps=num_warps,
    )
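A minimal usage sketch of the launcher above, assuming the shapes given in its docstring (q and k are rotated in place; the cos/sin caches hold one row per position, as in the benchmark script later in this diff; the concrete sizes here are illustrative):

import torch
from colossalai.kernel.triton import fused_rotary_embedding

# per-sequence lengths of an unpadded (packed) batch; total_tokens = lengths.sum()
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
total_tokens = int(lengths.sum())
q = torch.randn(total_tokens, 32, 64, dtype=torch.float16, device="cuda")  # [total_tokens, head_num, head_dim]
k = torch.randn(total_tokens, 32, 64, dtype=torch.float16, device="cuda")
cos_cache = torch.randn(1024, 64, dtype=torch.float16, device="cuda")      # [max_position_len, head_dim]
sin_cache = torch.randn(1024, 64, dtype=torch.float16, device="cuda")

fused_rotary_embedding(q, k, cos_cache, sin_cache, lengths)  # modifies q and k in place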
colossalai/kernel/triton/no_pad_rotary_embedding.py:
@@ -98,11 +98,12 @@ def rotary_embedding(
    Args:
        q: query tensor, [total_tokens, head_num, head_dim]
        k: key tensor, [total_tokens, head_num, head_dim]
        cos: cosine for rotary embedding, [total_tokens, head_dim]
        sin: sine for rotary embedding, [total_tokens, head_dim]
        cos: cosine for rotary embedding, [max_position_len, head_dim]
        sin: sine for rotary embedding, [max_position_len, head_dim]
        lengths [num_seqs]
    """
    q_total_tokens, q_head_num, head_dim = q.shape
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
    assert q.size(0) == k.size(0)
    BLOCK_HEAD = 4
    BLOCK_TOKENS = 8
    grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
colossalai/kernel/triton/rotary_cache_copy.py (new file):
@@ -0,0 +1,110 @@
import torch
import triton
import triton.language as tl


@triton.jit
def prefill_cache_kernel(
    CaChe,
    cumsum_lengths,
    output,
    cache_stride,
    hidden_stride,
    total_length,
    HIDDEN_DIM: tl.constexpr,
    N_ELEMENTS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    idx0 = tl.program_id(axis=0)
    idx1 = tl.program_id(axis=1)
    idx = idx0 * BLOCK_SIZE + idx1

    # position of this token within its original sequence
    cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
    ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
    _cache = tl.load(CaChe + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride)
    tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length)


@triton.jit
def decoding_cache_kernel(
    CaChe,
    lengths,
    output,
    cache_stride,
    hidden_stride,
    HIDDEN_DIM: tl.constexpr,
    NUM_SEQS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None)  # [BLOCK_SIZE,]
    _cache = tl.load(CaChe + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride)
    tl.store(
        output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
        _cache,
        mask=idx[:, None] < NUM_SEQS,
    )


@torch.no_grad()
def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
    """
    Transform the cos/sin cache into an unpadded (no-pad) sequence layout, with two different modes.
    Args:
        lengths: shape(num_seqs,), stores the length of each sequence.
        cache: shape(max_rotary_position(e.g. 2048), head_dim), cos/sin cache constructed in the model.
        is_prompts: bool, marks whether we are in prefill mode.
    For prefill mode:
        the cos/sin cache for each sequence covers all of its positions (one row per token).
    For decoding mode:
        the cos/sin cache is only needed for the last token of each sequence.
    """

    _, hidden_dim = cache.shape
    num_seqs = lengths.numel()

    BLOCK_SIZE = 16
    if hidden_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    cache_stride = cache.stride(0)
    hidden_stride = cache.stride(1)

    if is_prompts:
        total_length = lengths.sum().item()
        cumsum_lens = torch.cumsum(lengths, dim=0)
        output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device)
        grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
        prefill_cache_kernel[grid](
            cache,
            cumsum_lens,
            output,
            cache_stride,
            hidden_stride,
            total_length,
            HIDDEN_DIM=hidden_dim,
            N_ELEMENTS=triton.next_power_of_2(num_seqs),
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        # BUG: a memory access error occurs when a deepcopy of lengths is used here instead of lengths
        nlengths = torch.as_tensor(lengths) - 1
        output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device)
        grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
        decoding_cache_kernel[grid](
            cache,
            nlengths,
            output,
            cache_stride,
            hidden_stride,
            HIDDEN_DIM=hidden_dim,
            NUM_SEQS=num_seqs,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return output
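A minimal usage sketch of the two modes described in the docstring above (shapes follow the test at the end of this diff; the cache row count and head_dim here are illustrative):

import torch
from colossalai.kernel.triton import get_xine_cache

lengths = torch.tensor([3, 4, 6, 7], device="cuda")                    # [num_seqs]
cos_cache = torch.randn(2048, 64, dtype=torch.float16, device="cuda")  # [max_rotary_position, head_dim]

# prefill: one cache row per token of every sequence -> [lengths.sum(), head_dim] = [20, 64]
cos_prefill = get_xine_cache(lengths, cos_cache, is_prompts=True)
# decoding: only the last position of each sequence -> [num_seqs, head_dim] = [4, 64]
cos_decode = get_xine_cache(lengths, cos_cache, is_prompts=False)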
@@ -0,0 +1,93 @@
from copy import deepcopy

import torch
import triton

from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding
from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache

BATCH = 16
configs = [
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[2**i for i in range(4, 12)],
        line_arg="provider",
        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name=f"rotary_emb-batch-{BATCH}",
        args={"num_kv_heads": 16},
    )
]


def torch_rotary_emb(x, cos, sin):
    # reference implementation: rotate the two halves of the head dimension
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@triton.testing.perf_report(configs)
def benchmark_rotary_emb(
    provider: str,
    num_tokens: int,
    num_kv_heads: int,
):
    warmup = 10
    rep = 100

    head_dim = 128
    dtype = torch.float16
    q_shape = (num_tokens, num_kv_heads, head_dim)
    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
    k_shape = (num_tokens, num_kv_heads, head_dim)
    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
    cos_shape = (4096, head_dim // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # per-sequence lengths for the fused kernel (assumption: num_tokens split evenly across BATCH sequences)
    lengths = torch.full((BATCH,), num_tokens // BATCH, dtype=torch.int32, device="cuda")

    if provider == "torch_rotary_emb_func":
        fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
    elif provider == "triton_rotary_emb_func":
        fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
    else:
        raise ValueError("Undefined provider")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    return ms


if __name__ == "__main__":
    num_tokens = 20
    num_kv_heads = 32
    head_dim = 64
    dtype = torch.float32
    q_shape = (num_tokens, num_kv_heads, head_dim)
    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
    q_copy = deepcopy(q)

    k_shape = (num_tokens, num_kv_heads, head_dim)
    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
    k_copy = deepcopy(k)

    cos_shape = (1024, head_dim)
    lengths = torch.tensor([3, 4, 6, 7], device="cuda")
    cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")

    # prefill mode: one cos/sin row per token, matching q's total number of tokens
    cos = get_xine_cache(lengths, cos_cache[:, : head_dim // 2], is_prompts=True)
    sin = get_xine_cache(lengths, sin_cache[:, : head_dim // 2], is_prompts=True)

    rotary_embedding(q, k, cos, sin)
    fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths)
    assert torch.allclose(q, q_copy)
    assert torch.allclose(k, k_copy)

    # benchmark_rotary_emb.run(save_path=".", print_data=True)
@@ -0,0 +1,83 @@
import pytest
import torch
from packaging import version

from colossalai.inference.modeling.models.llama import get_cos_sin
from colossalai.kernel.triton import get_xine_cache

try:
    import triton  # noqa

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton and CUDA > 11.4")
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
    MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
    cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
    lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda")
    # prefill
    cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=True, dtype=dtype)
    cos = get_xine_cache(lengths, cos_cache, is_prompts=True)
    assert torch.allclose(cos, cos_ref)
    # decoding
    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=False, dtype=dtype)
    cos = get_xine_cache(lengths, cos_cache, is_prompts=False)
    assert torch.allclose(cos, ncos_ref)


configs = [
    triton.testing.Benchmark(
        x_names=["max_num_tokens"],
        x_vals=[2**i for i in range(6, 12)],
        line_arg="provider",
        line_vals=["torch_get_cos_sin_func", "triton_get_xine_func"],
        line_names=["torch_get_cos_sin_func", "triton_get_xine_func"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name="Get_cos-sin_func",
        args={"batch_size": 16, "head_dim": 256},
    )
]


@triton.testing.perf_report(configs)
def benchmark_get_xine_cache(
    provider: str,
    max_num_tokens: int,
    batch_size: int,
    head_dim: int,
):
    warmup = 10
    rep = 1000
    max_token_per_seq = max_num_tokens // batch_size
    dtype = torch.float16
    cos_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
    sin_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
    lengths = torch.randint(2, max_token_per_seq, (batch_size,), device="cuda")

    if provider == "torch_get_cos_sin_func":
        fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
    elif provider == "triton_get_xine_func":
        fn = lambda: [
            get_xine_cache(lengths, cos_cache, is_prompts=False),
            get_xine_cache(lengths, sin_cache, is_prompts=False),
        ]
    else:
        raise ValueError("Undefined provider")

    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    return ms


if __name__ == "__main__":
    test_get_xine_cache(4, 64, 256, torch.float32)
    # benchmark_get_xine_cache.run(save_path=".", print_data=True)