mirror of https://github.com/hpcaitech/ColossalAI
register meta func for rnn (#2159)
parent
cfe2a9bd90
commit
12e7bcd720
|
@ -200,19 +200,56 @@ def meta_adaptive_avg_pool2d_backward(
|
|||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||
@register_meta(aten._cudnn_rnn.default)
|
||||
def meta_cuda_rnn(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_stride0: int,
|
||||
weight_buf: torch.Tensor,
|
||||
hx: torch.Tensor,
|
||||
cx: Optional[torch.Tensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
input,
|
||||
weight,
|
||||
weight_stride0,
|
||||
weight_buf,
|
||||
hx,
|
||||
cx,
|
||||
mode,
|
||||
hidden_size,
|
||||
proj_size,
|
||||
num_layers,
|
||||
batch_first,
|
||||
dropout,
|
||||
train,
|
||||
bidirectional,
|
||||
batch_sizes,
|
||||
dropout_state,
|
||||
):
|
||||
if cx is not None:
|
||||
return torch.empty_like(input), torch.empty_like(hx), torch.empty_like(cx)
|
||||
|
||||
is_input_packed = len(batch_sizes) != 0
|
||||
if is_input_packed:
|
||||
seq_length = len(batch_sizes)
|
||||
mini_batch = batch_sizes[0]
|
||||
batch_sizes_sum = input.shape[0]
|
||||
else:
|
||||
return torch.empty_like(input), torch.empty_like(hx), torch.empty((), device='meta')
|
||||
seq_length = input.shape[1] if batch_first else input.shape[0]
|
||||
mini_batch = input.shape[0] if batch_first else input.shape[1]
|
||||
batch_sizes_sum = -1
|
||||
|
||||
num_directions = 2 if bidirectional else 1
|
||||
out_size = proj_size if proj_size != 0 else hidden_size
|
||||
if is_input_packed:
|
||||
out_shape = [batch_sizes_sum, out_size * num_directions]
|
||||
else:
|
||||
out_shape = (
|
||||
[mini_batch, seq_length, out_size * num_directions]
|
||||
if batch_first
|
||||
else [seq_length, mini_batch, out_size * num_directions]
|
||||
)
|
||||
output = input.new_empty(out_shape)
|
||||
|
||||
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
|
||||
cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)
|
||||
|
||||
hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
|
||||
|
||||
# TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
|
||||
reserve_shape = 0 if train else 0
|
||||
reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
|
||||
|
||||
return output, hy, cy, reserve, weight_buf
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
|
||||
|
|
Loading…
Reference in New Issue