register meta func for rnn (#2159)

pull/2172/head
Zihao 2022-12-21 23:06:18 +08:00 committed by GitHub
parent cfe2a9bd90
commit 12e7bcd720
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 48 additions and 11 deletions

View File

@@ -200,19 +200,56 @@ def meta_adaptive_avg_pool2d_backward(
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn.default)
def meta_cuda_rnn(
    input: torch.Tensor,
    weight,
    weight_stride0: int,
    weight_buf,
    hx: torch.Tensor,
    cx: Optional[torch.Tensor],
    mode: int,
    hidden_size: int,
    proj_size: int,
    num_layers: int,
    batch_first: bool,
    dropout: float,
    train: bool,
    bidirectional: bool,
    batch_sizes,
    dropout_state,
):
    """Shape-only (meta) implementation of ``aten._cudnn_rnn``.

    No RNN computation is performed; the function only allocates empty
    tensors whose shapes are derived from the arguments.

    Args:
        input: Input sequence. Treated as packed when ``batch_sizes`` is
            non-empty, in which case ``input.shape[0]`` is used as the first
            dimension of the output. Otherwise the first two dimensions are
            read as ``(batch, seq)`` if ``batch_first`` else ``(seq, batch)``.
        weight, weight_stride0, mode, dropout, dropout_state: Accepted to
            match the operator signature; not used for shape inference.
        weight_buf: Returned unchanged as the last element of the result.
        hx: Initial hidden state; ``hy`` is allocated via ``hx.new_empty``.
        cx: Initial cell state, or ``None`` when there is no cell state.
        hidden_size: Size of the hidden state (last dim of ``cy``).
        proj_size: Projection size; when non-zero it replaces ``hidden_size``
            as the feature size of ``output`` and ``hy``.
        num_layers: Number of stacked layers.
        batch_first: Whether a non-packed ``input`` is laid out batch-first.
        train: Currently has no effect on the result (see the TODO below).
        bidirectional: Doubles the direction count when true.
        batch_sizes: Sequence supporting ``len()`` and indexing; empty means
            the input is not packed. Its first element is used as the
            mini-batch size for packed input.

    Returns:
        A 5-tuple ``(output, hy, cy, reserve, weight_buf)``.
    """
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1

    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    if cx is None:
        # Allocate the placeholder on the input's device so every returned
        # tensor lives on the same device instead of the default one.
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    # Until then the reserve buffer is empty regardless of `train`.
    reserve_shape = 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp