# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py # should be activated for PyTorch version 1.12.0 and below # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations from typing import Callable, List, Optional, Tuple, Union import torch from torch.utils._pytree import tree_map aten = torch.ops.aten meta_lib = torch.library.Library("aten", "IMPL", "Meta") meta_table = {} def register_meta(op, register_dispatcher=True): def wrapper(f): def add_func(op): meta_table[op] = f if register_dispatcher: name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) try: meta_lib.impl(name, f) except: pass tree_map(add_func, op) return f return wrapper # ============================== Convolutions ====================================== # https://github.com/pytorch/pytorch/pull/79834 @register_meta(aten.convolution.default) def meta_conv( input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, ): def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html Args: ln: length of the dimension p: padding in that dim d: dilation in that dim k: kernel size in that dim s: stride in that dim Returns: The output length """ return (ln + 2 * p - d * (k - 1) - 1) // s + 1 def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int: """ Formula to apply to calculate the length of some dimension of the output if transposed convolution is used. See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html Args: ln: length of the dimension p: padding in that dim d: dilation in that dim k: kernel size in that dim s: stride in that dim op: output padding in that dim Returns: The output length """ return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1 def calc_conv_nd_return_shape( dims: torch.Size, kernel_size: torch.Size, stride: Union[List[int], int], padding: Union[List[int], int], dilation: Union[List[int], int], output_padding: Optional[Union[List[int], int]] = None, ): ret_shape = [] if isinstance(stride, int): stride = [stride] * len(dims) elif len(stride) == 1: stride = [stride[0]] * len(dims) if isinstance(padding, int): padding = [padding] * len(dims) elif len(padding) == 1: padding = [padding[0]] * len(dims) if isinstance(dilation, int): dilation = [dilation] * len(dims) elif len(dilation) == 1: dilation = [dilation[0]] * len(dims) output_padding_list: Optional[List[int]] = None if output_padding: if isinstance(output_padding, int): output_padding_list = [output_padding] * len(dims) elif len(output_padding) == 1: output_padding_list = [output_padding[0]] * len(dims) else: output_padding_list = output_padding for i in range(len(dims)): # If output_padding is present, we are dealing with a transposed convolution if output_padding_list: ret_shape.append( _formula_transposed( dims[i], padding[i], dilation[i], kernel_size[i], stride[i], output_padding_list[i], )) else: ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape def pick_memory_format(): if input_tensor.is_contiguous(memory_format=torch.channels_last): return torch.channels_last elif input_tensor.is_contiguous(memory_format=torch.contiguous_format): return torch.contiguous_format elif input_tensor.is_contiguous(memory_format=torch.preserve_format): return torch.preserve_format kernel_size = weight.shape[2:] dims = input_tensor.shape[2:] if is_transposed: out_channels = groups * weight.shape[1] shape_out = calc_conv_nd_return_shape( dims, kernel_size, stride, padding, dilation, output_padding, ) else: out_channels = weight.shape[0] if weight.shape[1] != input_tensor.shape[1] / groups: raise RuntimeError("Invalid channel dimensions") shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, *extra_args): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask): return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @register_meta(aten._adaptive_avg_pool2d_backward.default) def meta_adaptive_avg_pool2d_backward( grad_output: torch.Tensor, input: torch.Tensor, ): grad_input = torch.empty_like(input) return grad_input # ================================ RNN ============================================= # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn.default) def meta_cuda_rnn( input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, ): is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) mini_batch = batch_sizes[0] batch_sizes_sum = input.shape[0] else: seq_length = input.shape[1] if batch_first else input.shape[0] mini_batch = input.shape[0] if batch_first else input.shape[1] batch_sizes_sum = -1 num_directions = 2 if bidirectional else 1 out_size = proj_size if proj_size != 0 else hidden_size if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: out_shape = ([mini_batch, seq_length, out_size * num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape) hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size]) # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python) reserve_shape = 0 if train else 0 reserve = input.new_empty(reserve_shape, dtype=torch.uint8) return output, hy, cy, reserve, weight_buf # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) def meta_cudnn_rnn_backward(input: torch.Tensor, weight: torch.Tensor, weight_stride0: int, hx: torch.Tensor, cx: Optional[torch.Tensor] = None, *args, **kwargs): print(input, weight, hx, cx) grad_input = torch.empty_like(input) grad_weight = torch.empty_like(weight) grad_hx = torch.empty_like(hx) grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') return grad_input, grad_weight, grad_hx, grad_cx # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp # ============================== Activations ======================================= @register_meta(aten.relu.default) def meta_relu(input: torch.Tensor): return torch.empty_like(input) @register_meta(aten.prelu.default) def meta_prelu(input: torch.Tensor, weight: torch.Tensor): return torch.empty_like(input) @register_meta(aten.hardswish.default) def meta_hardswish(input: torch.Tensor): return torch.empty_like(input) @register_meta(aten.hardtanh.default) def meta_hardtanh(input: torch.Tensor, min, max): return torch.empty_like(input) @register_meta(aten.hardswish_backward.default) def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): grad_in = torch.empty_like(input) return grad_in @register_meta(aten.hardtanh_backward.default) def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int): grad_in = torch.empty_like(input) return grad_in # ============================== Normalization ===================================== # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm.default) def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) output = torch.empty_like(input) running_mean = torch.empty((n_input), device='meta') running_var = torch.empty((n_input), device='meta') return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) return dX, dgamma, dbeta # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.cudnn_batch_norm.default) def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) output = torch.empty_like(input) running_mean = torch.empty((n_input), device='meta') running_var = torch.empty((n_input), device='meta') reserve = torch.empty((0), dtype=torch.uint8, device='meta') return output, running_mean, running_var, reserve # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # NB: CuDNN only implements the backward algorithm for batchnorm # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, save_invstd, eps, reserve): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) return dX, dgamma, dbeta # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm.default) def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): bs = input.size(0) n_input = input.size(1) output = torch.empty_like(input) running_mean = torch.empty((bs, n_input, 1), device='meta') running_var = torch.empty((bs, n_input, 1), device='meta') return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(bias) return dX, dgamma, dbeta # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp @register_meta(aten.native_group_norm_backward.default) def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(gamma) dbeta = torch.empty_like(gamma) return dX, dgamma, dbeta # ================================== Misc ========================================== # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml @register_meta(aten.roll.default) def meta_roll(input: torch.Tensor, shifts, dims): return input # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp @register_meta(aten._local_scalar_dense.default) def meta_local_scalar_dense(self: torch.Tensor): return 0 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp @register_meta(aten.where.self) def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor): result_type = torch.result_type(self, other) return torch.empty_like(self, dtype=result_type) @register_meta(aten.index.Tensor) def meta_index_Tensor(self, indices): assert indices, "at least one index must be provided" # aten::index is the internal advanced indexing implementation # checkIndexTensorTypes and expandTensors result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: assert index.dtype in [torch.long, torch.int8, torch.bool],\ "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): assert index.shape[j] == self.shape[ k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) else: result.append(index) indices = result assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace import torch._refs as refs indices = list(refs._maybe_broadcast(*indices)) # add missing null tensors while len(indices) < self.ndim: indices.append(None) # hasContiguousSubspace # true if all non-null tensors are adjacent # See: # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency state = 0 has_contiguous_subspace = False for index in indices: if state == 0: if index is not None: state = 1 elif state == 1: if index is None: state = 2 else: if index is not None: break else: has_contiguous_subspace = True # transposeToFront # This is the logic that causes the newly inserted dimensions to show up # at the beginning of the tensor, if they're not contiguous if not has_contiguous_subspace: dims = [] transposed_indices = [] for i, index in enumerate(indices): if index is not None: dims.append(i) transposed_indices.append(index) for i, index in enumerate(indices): if index is None: dims.append(i) transposed_indices.append(index) self = self.permute(dims) indices = transposed_indices # AdvancedIndex::AdvancedIndex # Now we can assume the indices have contiguous subspace # This is simplified from AdvancedIndex which goes to more effort # to put the input and indices in a form so that TensorIterator can # take them. If we write a ref for this, probably that logic should # get implemented before_shape: List[int] = [] after_shape: List[int] = [] replacement_shape: List[int] = [] for dim, index in enumerate(indices): if index is None: if replacement_shape: after_shape.append(self.shape[dim]) else: before_shape.append(self.shape[dim]) else: replacement_shape = list(index.shape) return self.new_empty(before_shape + replacement_shape + after_shape) # ============================== Embedding ========================================= # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq): return torch.empty((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, device=grad_output.device, layout=grad_output.layout) # ============================== Dropout =========================================== # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp @register_meta(aten.native_dropout.default) def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): # notice that mask is bool output = torch.empty_like(input) mask = torch.empty_like(input, dtype=torch.bool) return output, mask # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp @register_meta(aten.native_dropout_backward.default) def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): return torch.empty_like(grad)