mirror of https://github.com/hpcaitech/ColossalAI
[feat] add flash attention (#1762)
parent
0584654c79
commit
25952b67d7
|
@ -0,0 +1,331 @@
|
|||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
except ImportError:
|
||||
raise ImportError('please install triton from https://github.com/openai/triton')
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||
except ImportError:
|
||||
raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
|
||||
|
||||
|
||||
def triton_check():
|
||||
cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
|
||||
cuda_version = subprocess.check_output([os.path.join(cuda_home, "bin/nvcc"), "--version"]).decode().strip()
|
||||
cuda_version = cuda_version.split('release ')[1]
|
||||
cuda_version = cuda_version.split(',')[0]
|
||||
cuda_version = cuda_version.split('.')
|
||||
if len(cuda_version) == 2 and \
|
||||
(int(cuda_version[0]) == 11 and int(cuda_version[1]) >= 4) or \
|
||||
int(cuda_version[0]) > 11:
|
||||
return True
|
||||
return False
|
||||
|
||||
TRITON_AVALIABLE = triton_check()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hz * N_CTX + offs_m
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_ptrs)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
tl.store(t_ptrs, acc_scale)
|
||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
p = p.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_DMODEL)
|
||||
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
|
||||
out_ptrs = Out + off_o
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
denom = tl.load(L + off_m).to(tl.float32)
|
||||
# compute
|
||||
do = do / denom[:, None]
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
off_h = off_hz % H
|
||||
# offset pointers for batch/head
|
||||
Q += off_z * stride_qz + off_h * stride_qh
|
||||
K += off_z * stride_qz + off_h * stride_qh
|
||||
V += off_z * stride_qz + off_h * stride_qh
|
||||
DO += off_z * stride_qz + off_h * stride_qh
|
||||
DQ += off_z * stride_qz + off_h * stride_qh
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
lo = start_n * BLOCK_M
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_m = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
m_ptrs = M + off_hz * N_CTX
|
||||
# initialize dv amd dk
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
# loop over rows
|
||||
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(tl.float16), do, trans_a=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v, trans_b=True)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp * sm_scale
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds.to(tl.float16), q, trans_a=True)
|
||||
# # compute dq
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dq += tl.dot(ds.to(tl.float16), k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
# # increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_qm
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _TritonFlashAttention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, sm_scale):
|
||||
BLOCK = 128
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.BLOCK = BLOCK
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
|
||||
# NOTE: kernel currently buggy for other values of `num_warps`
|
||||
num_warps = 8
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
|
||||
|
||||
def triton_flash_attention(q, k, v, sm_scale):
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch, nheads, seq, headdim)
|
||||
k: (batch, nheads, seq, headdim)
|
||||
v: (batch, nheads, seq, headdim)
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Return:
|
||||
out: (batch, nheads, seq, headdim)
|
||||
"""
|
||||
if TRITON_AVALIABLE:
|
||||
return _TritonFlashAttention.apply(q, k, v, sm_scale)
|
||||
else:
|
||||
raise RuntimeError("Triton kernel requires CUDA 11.4+!")
|
||||
|
||||
|
||||
def flash_attention(q, k, v, sm_scale, batch_size, seq_len, dropout_p=0., causal=True):
|
||||
"""
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
|
||||
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
|
||||
batch_size: int.
|
||||
seq_len: int.
|
||||
dropout_p: float. Dropout probability.
|
||||
sm_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
"""
|
||||
lengths = torch.full((batch_size,), fill_value=seq_len, device=q.device)
|
||||
cu_seqlens = torch.zeros((batch_size + 1,), device=q.device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
return flash_attn_unpadded_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seq_len, max_seqlen_k=seq_len,
|
||||
dropout_p=dropout_p, softmax_scale=sm_scale, causal=causal)
|
|
@ -7,3 +7,6 @@ titans
|
|||
torchaudio
|
||||
torchrec
|
||||
contexttimer
|
||||
einops
|
||||
triton==2.0.0.dev20221011
|
||||
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
|
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
import pytest
|
||||
from einops import rearrange
|
||||
from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
|
||||
|
||||
|
||||
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
ref_out = torch.matmul(p, v)
|
||||
return ref_out
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
|
||||
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
dout = torch.randn_like(q)
|
||||
|
||||
ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
|
||||
# triton implementation
|
||||
if TRITON_AVALIABLE:
|
||||
tri_out = triton_flash_attention(q, k, v, sm_scale)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
assert torch.allclose(ref_out, tri_out, atol=1e-3)
|
||||
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
|
||||
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
|
||||
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
|
||||
else:
|
||||
try:
|
||||
tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Error type not match!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
|
||||
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
dout = torch.randn_like(q)
|
||||
|
||||
# reference implementation
|
||||
ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
|
||||
# flash implementation
|
||||
q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
|
||||
tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
|
||||
dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
|
||||
tri_out.backward(dout, retain_graph=True)
|
||||
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
|
||||
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))
|
||||
|
||||
# compare
|
||||
assert torch.allclose(ref_out, tri_out, atol=1e-3)
|
||||
assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
|
||||
assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
|
||||
assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
|
Loading…
Reference in New Issue