mirror of https://github.com/hpcaitech/ColossalAI
register meta func for rnn (#2159)
parent
cfe2a9bd90
commit
12e7bcd720
|
@ -200,19 +200,56 @@ def meta_adaptive_avg_pool2d_backward(
|
||||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
@register_meta(aten._cudnn_rnn.default)
def meta_cuda_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Meta (shape-only) implementation of ``aten._cudnn_rnn``.

    Computes the tensor shapes the real cuDNN RNN kernel would produce
    without executing it, so shape/dtype propagation can run on meta
    tensors. The shape logic mirrors ATen's native cudnn/RNN.cpp (link
    above).

    Returns:
        Tuple ``(output, hy, cy, reserve, weight_buf)`` of empty tensors
        with the shapes the real op would return. ``weight_buf`` is
        passed through unchanged.
    """
    # Non-empty batch_sizes means the input comes from a PackedSequence:
    # layout (batch_sizes_sum, input_size) rather than (seq, batch, input_size).
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1  # sentinel; unused in the unpacked branch

    num_directions = 2 if bidirectional else 1
    # With LSTM projections (proj_size > 0) each step emits proj_size
    # features; otherwise hidden_size.
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    # cy only exists for LSTMs (cx given); otherwise return an empty tensor,
    # matching the eager op's contract.
    cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python).
    # Placeholder size; the original `0 if train else 0` was a dead
    # conditional (both branches 0), simplified to a plain 0.
    reserve_shape = 0
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
||||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||||
|
|
Loading…
Reference in New Issue