diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py
index d614219db..8c0201c71 100644
--- a/colossalai/fx/_meta_registrations.py
+++ b/colossalai/fx/_meta_registrations.py
@@ -200,19 +200,56 @@ def meta_adaptive_avg_pool2d_backward(
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
 @register_meta(aten._cudnn_rnn.default)
 def meta_cuda_rnn(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_stride0: int,
-    weight_buf: torch.Tensor,
-    hx: torch.Tensor,
-    cx: Optional[torch.Tensor] = None,
-    *args,
-    **kwargs,
+    input,
+    weight,
+    weight_stride0,
+    weight_buf,
+    hx,
+    cx,
+    mode,
+    hidden_size,
+    proj_size,
+    num_layers,
+    batch_first,
+    dropout,
+    train,
+    bidirectional,
+    batch_sizes,
+    dropout_state,
 ):
-    if cx is not None:
-        return torch.empty_like(input), torch.empty_like(hx), torch.empty_like(cx)
+
+    is_input_packed = len(batch_sizes) != 0
+    if is_input_packed:
+        seq_length = len(batch_sizes)
+        mini_batch = batch_sizes[0]
+        batch_sizes_sum = input.shape[0]
     else:
-        return torch.empty_like(input), torch.empty_like(hx), torch.empty((), device='meta')
+        seq_length = input.shape[1] if batch_first else input.shape[0]
+        mini_batch = input.shape[0] if batch_first else input.shape[1]
+        batch_sizes_sum = -1
+
+    num_directions = 2 if bidirectional else 1
+    out_size = proj_size if proj_size != 0 else hidden_size
+    if is_input_packed:
+        out_shape = [batch_sizes_sum, out_size * num_directions]
+    else:
+        out_shape = (
+            [mini_batch, seq_length, out_size * num_directions]
+            if batch_first
+            else [seq_length, mini_batch, out_size * num_directions]
+        )
+    output = input.new_empty(out_shape)
+
+    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
+    cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)
+
+    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
+
+    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
+    reserve_shape = 0 if train else 0
+    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
+
+    return output, hy, cy, reserve, weight_buf
 
 
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
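
For context (not part of the patch): the unpacked branch of meta_cuda_rnn above derives the output, hy, and cy shapes purely from the input layout, so it can be sanity-checked against a real nn.LSTM on CPU. The sketch below restates that shape arithmetic in a standalone helper; rnn_output_shapes is a hypothetical name introduced only for this illustration and exists neither in the patch nor in PyTorch.

    # Standalone restatement of the unpacked shape logic from meta_cuda_rnn,
    # checked against the shapes a real LSTM actually produces.
    import torch
    import torch.nn as nn

    def rnn_output_shapes(input, hidden_size, proj_size, num_layers,
                          batch_first, bidirectional):
        # Mirrors the is_input_packed == False path of the meta kernel.
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        num_directions = 2 if bidirectional else 1
        out_size = proj_size if proj_size != 0 else hidden_size
        out_shape = ([mini_batch, seq_length, out_size * num_directions]
                     if batch_first
                     else [seq_length, mini_batch, out_size * num_directions])
        hy_shape = [num_layers * num_directions, mini_batch, out_size]
        cy_shape = [num_layers * num_directions, mini_batch, hidden_size]
        return out_shape, hy_shape, cy_shape

    x = torch.randn(4, 7, 10)  # (batch, seq, feature) in batch_first layout
    lstm = nn.LSTM(10, 20, num_layers=2, batch_first=True, bidirectional=True)
    out, (hn, cn) = lstm(x)
    out_s, hy_s, cy_s = rnn_output_shapes(x, 20, 0, 2, True, True)
    assert list(out.shape) == out_s  # [4, 7, 40]
    assert list(hn.shape) == hy_s    # [4, 4, 20]
    assert list(cn.shape) == cy_s    # [4, 4, 20]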