mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
506 lines
19 KiB
506 lines
19 KiB
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py |
|
# should be activated for PyTorch version 1.12.0 and below |
|
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml |
|
# for more meta_registrations |
|
|
|
from typing import Callable, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch.utils._pytree import tree_map |
|
|
|
# Shorthand for the ATen operator namespace used by every registration below.
aten = torch.ops.aten

# Library handle used to attach Python implementations to the "Meta"
# dispatch key of the existing "aten" namespace.
meta_lib = torch.library.Library("aten", "IMPL", "Meta")

# Maps each registered operator overload to its Python meta implementation.
meta_table = {}
|
|
|
|
|
def register_meta(op, register_dispatcher=True):
    """Build a decorator that registers a function as the meta kernel of ``op``.

    Args:
        op: an operator overload, or any pytree (list, tuple, dict, ...) of
            overloads; the decorated function is registered for each leaf.
        register_dispatcher: when true, additionally try to install the
            function on ``meta_lib`` so the dispatcher can call it.

    Returns:
        A decorator that records the function in ``meta_table`` and returns it
        unchanged.
    """

    def wrapper(f):

        def add_func(op):
            meta_table[op] = f
            if register_dispatcher:
                # Non-default overloads are registered under the overload's own
                # name, default overloads under the name of their packet.
                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
                try:
                    meta_lib.impl(name, f)
                except Exception:
                    # Registration is best effort: a failure here must not
                    # prevent the module from importing. Only ``Exception`` is
                    # caught so KeyboardInterrupt / SystemExit still propagate.
                    pass

        tree_map(add_func, op)
        return f

    return wrapper
|
|
|
|
|
# ============================== Convolutions ====================================== |
|
# https://github.com/pytorch/pytorch/pull/79834 |
|
@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Allocate the output of a (transposed) convolution without computing it.

    Only shapes are inspected: the batch size comes from ``input_tensor``, the
    channel count from ``weight`` (and ``groups`` for transposed convolutions)
    and each spatial extent from the usual convolution arithmetic.

    Args:
        input_tensor: input whose dims from index 2 on are spatial.
        weight: kernel whose dims from index 2 on are the kernel size.
        bias: unused; accepted to match the operator signature.
        stride, padding, dilation: one value per spatial dim, or a single
            value applied to every spatial dim.
        is_transposed: selects the transposed-convolution arithmetic.
        output_padding: extra output size, only used when ``is_transposed``.
        groups: number of channel groups.

    Returns:
        An uninitialised tensor created with ``input_tensor.new_empty`` and
        converted to the memory format picked from ``input_tensor``.

    Raises:
        RuntimeError: for a regular convolution whose weight/input channel
            counts are inconsistent with ``groups``.
    """

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        if transposed convolution is used.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim
        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    def _per_dim(value: Union[List[int], int], ndim: int) -> List[int]:
        """Expand an int or a one-element sequence to one entry per spatial dim."""
        if isinstance(value, int):
            return [value] * ndim
        if len(value) == 1:
            return [value[0]] * ndim
        return value

    def calc_conv_nd_return_shape(
        dims: torch.Size,
        kernel_size: torch.Size,
        stride: Union[List[int], int],
        padding: Union[List[int], int],
        dilation: Union[List[int], int],
        output_padding: Optional[Union[List[int], int]] = None,
    ):
        """Return the list of output spatial sizes.

        Passing ``output_padding`` (anything but ``None`` or an empty
        sequence) selects the transposed-convolution formula.
        """
        ndim = len(dims)
        stride = _per_dim(stride, ndim)
        padding = _per_dim(padding, ndim)
        dilation = _per_dim(dilation, ndim)

        # An int 0 is a legitimate output padding, so test for absence
        # explicitly instead of relying on truthiness.
        absent = output_padding is None or (not isinstance(output_padding, int) and len(output_padding) == 0)
        if absent:
            return [_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) for i in range(ndim)]

        output_padding_list = _per_dim(output_padding, ndim)
        return [
            _formula_transposed(
                dims[i],
                padding[i],
                dilation[i],
                kernel_size[i],
                stride[i],
                output_padding_list[i],
            ) for i in range(ndim)
        ]

    def pick_memory_format():
        """Mirror the input's channels-last layout, otherwise stay contiguous."""
        if input_tensor.is_contiguous(memory_format=torch.channels_last):
            return torch.channels_last
        # Every other layout maps to the default format, which is what
        # ``new_empty`` already produced, so this never returns None.
        return torch.contiguous_format

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    if is_transposed:
        out_channels = groups * weight.shape[1]
        shape_out = calc_conv_nd_return_shape(
            dims,
            kernel_size,
            stride,
            padding,
            dilation,
            output_padding,
        )
    else:
        out_channels = weight.shape[0]
        # Integer comparison; equivalent to weight.shape[1] != C_in / groups
        # without going through floating point.
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")
        shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)

    out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
    mem_fmt = pick_memory_format()
    out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
    return out
|
|
|
|
|
@register_meta(aten._convolution.default)
def meta_conv_1(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
    *extra_args,
):
    """Meta kernel for ``aten._convolution``.

    Forwards the first nine arguments to ``meta_conv``; any further
    positional arguments (``extra_args``) are ignored.
    """
    return meta_conv(
        input_tensor,
        weight,
        bias,
        stride,
        padding,
        dilation,
        is_transposed,
        output_padding,
        groups,
    )
|
|
|
|
|
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
                       padding, dilation, transposed, output_padding, groups, output_mask):
    """Allocate the gradients produced by ``aten.convolution_backward``.

    Args:
        grad_output: gradient w.r.t. the convolution output; supplies the
            dtype of the bias gradient.
        input: forward input; the input gradient has its shape.
        weight: forward weight; the weight gradient has its shape.
        bias_sizes: shape of the bias, or ``None`` when there is no bias.
        stride, padding, dilation, transposed, output_padding, groups,
        output_mask: unused; accepted to match the operator signature.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``; ``grad_bias`` is ``None``
        when ``bias_sizes`` is ``None``.
    """
    grad_input = torch.empty_like(input)
    grad_weight = torch.empty_like(weight)
    if bias_sizes is None:
        # No bias in the forward pass: there is no shape to allocate, and
        # torch.empty(None) would raise.
        grad_bias = None
    else:
        grad_bias = torch.empty(bias_sizes, dtype=grad_output.dtype, device='meta')
    return grad_input, grad_weight, grad_bias
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp |
|
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
):
    """Return an uninitialised input gradient laid out like ``input``.

    ``grad_output`` is not inspected.
    """
    return torch.empty_like(input)
|
|
|
|
|
# ================================ RNN ============================================= |
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp |
|
@register_meta(aten._cudnn_rnn.default)
def meta_cuda_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Allocate the outputs of ``aten._cudnn_rnn`` from shape information only.

    The input counts as packed when ``batch_sizes`` is non-empty; the output
    is then 2-D, ``[input.shape[0], features]``. Otherwise it is 3-D and
    ordered according to ``batch_first``.

    Returns:
        ``(output, hy, cy, reserve, weight_buf)``. ``cy`` is an empty tensor on
        the input's device when ``cx`` is ``None``; ``reserve`` is an empty
        uint8 tensor; ``weight_buf`` is passed through unchanged.
    """
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        mini_batch = batch_sizes[0]
    elif batch_first:
        mini_batch, seq_length = input.shape[0], input.shape[1]
    else:
        seq_length, mini_batch = input.shape[0], input.shape[1]

    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    out_features = out_size * num_directions
    if is_input_packed:
        # For packed input the first dim already holds the sum of batch sizes.
        out_shape = [input.shape[0], out_features]
    elif batch_first:
        out_shape = [mini_batch, seq_length, out_features]
    else:
        out_shape = [seq_length, mini_batch, out_features]
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    if cx is None:
        # Keep the placeholder on the same device as the other outputs rather
        # than on the default device.
        cy = torch.empty(0, device=input.device)
    else:
        cy = cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
    # Until then the reserve space is empty regardless of ``train``.
    reserve = input.new_empty(0, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp |
|
@register_meta(aten._cudnn_rnn_backward.default)
def meta_cudnn_rnn_backward(input: torch.Tensor,
                            weight: torch.Tensor,
                            weight_stride0: int,
                            hx: torch.Tensor,
                            cx: Optional[torch.Tensor] = None,
                            *args,
                            **kwargs):
    """Allocate the gradients returned by ``aten._cudnn_rnn_backward``.

    Args:
        input, weight, hx: forward tensors; each gradient copies the layout of
            its tensor.
        weight_stride0: unused.
        cx: optional cell state; when ``None`` a zero-dim meta tensor is
            returned in place of its gradient.
        *args, **kwargs: remaining operator arguments, ignored.

    Returns:
        ``(grad_input, grad_weight, grad_hx, grad_cx)``.
    """
    # A debug print of all four tensors used to run here on every call; a
    # shape-only kernel should not write to stdout.
    grad_input = torch.empty_like(input)
    grad_weight = torch.empty_like(weight)
    grad_hx = torch.empty_like(hx)
    grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
    return grad_input, grad_weight, grad_hx, grad_cx
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp |
|
# ============================== Activations ======================================= |
|
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
    """Return an uninitialised tensor laid out like ``input``."""
    result = torch.empty_like(input)
    return result
|
|
|
|
|
@register_meta(aten.prelu.default)
def meta_prelu(input: torch.Tensor, weight: torch.Tensor):
    """Return an uninitialised tensor laid out like ``input``; ``weight`` is unused."""
    result = torch.empty_like(input)
    return result
|
|
|
|
|
@register_meta(aten.hardswish.default)
def meta_hardswish(input: torch.Tensor):
    """Return an uninitialised tensor laid out like ``input``."""
    result = torch.empty_like(input)
    return result
|
|
|
|
|
@register_meta(aten.hardtanh.default)
def meta_hardtanh(input: torch.Tensor, min, max):
    """Return an uninitialised tensor laid out like ``input``.

    The clamp bounds ``min`` and ``max`` are not inspected.
    """
    result = torch.empty_like(input)
    return result
|
|
|
|
|
@register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
    """Return an uninitialised input gradient laid out like ``input``.

    ``grad_out`` is not inspected.
    """
    return torch.empty_like(input)
|
|
|
|
|
@register_meta(aten.hardtanh_backward.default)
def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int):
    """Return an uninitialised input gradient laid out like ``input``.

    ``grad_out``, ``min_val`` and ``max_val`` are not inspected.
    """
    return torch.empty_like(input)
|
|
|
|
|
# ============================== Normalization ===================================== |
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp |
|
@register_meta(aten.native_batch_norm.default)
def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
    """Allocate the three outputs of ``aten.native_batch_norm``.

    Only ``input`` is inspected. The incoming ``running_mean`` and
    ``running_var`` are ignored; fresh 1-D meta tensors of length
    ``input.size(1)`` are returned in their place.

    Returns:
        ``(output, mean, var)`` where ``output`` is laid out like ``input``.
    """
    num_features = input.size(1)
    return (
        torch.empty_like(input),
        torch.empty(num_features, device='meta'),
        torch.empty(num_features, device='meta'),
    )
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp |
|
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
                     save_invstd, train, eps, output_mask):
    """Allocate the gradients of ``aten.native_batch_norm_backward``.

    Args:
        dY: gradient w.r.t. the output; not inspected.
        input: forward input; ``dX`` copies its layout.
        weight: affine scale, or ``None`` for a non-affine batch norm.
        running_mean, running_var, save_mean, save_invstd, train, eps,
        output_mask: unused; accepted to match the operator signature.

    Returns:
        ``(dX, dgamma, dbeta)``. The two parameter gradients are ``None`` when
        ``weight`` is ``None`` instead of failing in ``torch.empty_like``.
    """
    dX = torch.empty_like(input)
    if weight is None:
        return dX, None, None
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(weight)
    return dX, dgamma, dbeta
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp |
|
@register_meta(aten.cudnn_batch_norm.default)
def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
    """Allocate the four outputs of ``aten.cudnn_batch_norm``.

    Only ``input`` is inspected; the two statistics are fresh 1-D meta tensors
    of length ``input.size(1)`` and the reserve space is an empty uint8 meta
    tensor.

    Returns:
        ``(output, mean, var, reserve)``.
    """
    num_features = input.size(1)
    return (
        torch.empty_like(input),
        torch.empty(num_features, device='meta'),
        torch.empty(num_features, device='meta'),
        torch.empty(0, dtype=torch.uint8, device='meta'),
    )
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp |
|
# NB: CuDNN only implements the backward algorithm for batchnorm |
|
# in training mode (evaluation mode batchnorm has a different algorithm), |
|
# which is why this doesn't accept a 'training' parameter. |
|
@register_meta(aten.cudnn_batch_norm_backward.default)
def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                           save_mean, save_invstd, eps, reserve):
    """Allocate the gradients of ``aten.cudnn_batch_norm_backward``.

    ``dX`` copies the layout of the second positional argument and both
    parameter gradients copy the layout of ``weight``; nothing else is read.

    NOTE(review): the ATen schema appears to list ``input`` before
    ``grad_output`` for this operator, the reverse of the parameter names
    here. Confirm against native_functions.yaml.

    Returns:
        ``(dX, dgamma, dbeta)``.
    """
    return (
        torch.empty_like(input),
        torch.empty_like(weight),
        torch.empty_like(weight),
    )
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp |
|
@register_meta(aten.native_layer_norm.default)
def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    """Allocate the three outputs of ``aten.native_layer_norm``.

    The statistics keep the leading (non-normalised) dims of ``input`` and
    use size 1 for each of the trailing ``len(normalized_shape)`` dims. For a
    3-D input normalised over its last dim this is
    ``(input.size(0), input.size(1), 1)``, the only case the previous
    hard-coded shape handled.

    Args:
        input: tensor being normalised.
        normalized_shape: sizes of the trailing dims that are normalised.
        weight, bias, eps: unused; accepted to match the operator signature.

    Returns:
        ``(output, mean, rstd)`` where ``output`` is laid out like ``input``
        and the statistics are meta tensors.
    """
    num_normalized = len(normalized_shape)
    axis = input.dim() - num_normalized
    stat_shape = list(input.shape[:axis]) + [1] * num_normalized

    output = torch.empty_like(input)
    running_mean = torch.empty(stat_shape, device='meta')
    running_var = torch.empty(stat_shape, device='meta')
    return output, running_mean, running_var
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp |
|
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                     grad_input_mask):
    """Allocate the gradients of ``aten.native_layer_norm_backward``.

    Args:
        dY: gradient w.r.t. the output; not inspected.
        input: forward input; ``dX`` copies its layout.
        weight, bias: affine parameters, either of which may be ``None``.
        normalized_shape, mean, rstd, grad_input_mask: unused; accepted to
            match the operator signature.

    Returns:
        ``(dX, dgamma, dbeta)``. A parameter gradient is ``None`` when the
        matching parameter is ``None`` instead of failing in
        ``torch.empty_like``.
    """
    dX = torch.empty_like(input)
    dgamma = None if weight is None else torch.empty_like(weight)
    dbeta = None if bias is None else torch.empty_like(bias)
    return dX, dgamma, dbeta
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp |
|
@register_meta(aten.native_group_norm_backward.default)
def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
    """Allocate the gradients of ``aten.native_group_norm_backward``.

    Args:
        dY: gradient w.r.t. the output; not inspected.
        input: forward input; ``dX`` copies its layout.
        gamma: affine scale, or ``None`` for a non-affine group norm.
        mean, rstd, N, C, HxW, group, grad_input_mask: unused; accepted to
            match the operator signature.

    Returns:
        ``(dX, dgamma, dbeta)``. The two parameter gradients are ``None`` when
        ``gamma`` is ``None`` instead of failing in ``torch.empty_like``.
    """
    dX = torch.empty_like(input)
    if gamma is None:
        return dX, None, None
    dgamma = torch.empty_like(gamma)
    dbeta = torch.empty_like(gamma)
    return dX, dgamma, dbeta
|
|
|
|
|
# ================================== Misc ========================================== |
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml |
|
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
    """Return an uninitialised tensor laid out like ``input``.

    ``shifts`` and ``dims`` are not inspected. A new tensor is allocated
    rather than handing ``input`` back, so the result never aliases the
    argument.
    """
    return torch.empty_like(input)
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp |
|
@register_meta(aten._local_scalar_dense.default)
def meta_local_scalar_dense(self: torch.Tensor):
    """Return the constant ``0``; the tensor argument is never read."""
    placeholder = 0
    return placeholder
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp |
|
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
    """Allocate the output of ``aten.where.self``.

    The dtype is ``torch.result_type(self, other)`` and the shape is the
    broadcast of all three argument shapes.

    Returns:
        An uninitialised tensor. When the broadcast shape equals
        ``self.shape`` it copies the layout of ``self``, as before; otherwise
        it is created with ``self.new_empty``.
    """
    result_type = torch.result_type(self, other)
    out_shape = torch.broadcast_shapes(condition.shape, self.shape, other.shape)
    if out_shape == self.shape:
        # No broadcasting of ``self`` needed: keep the original behaviour.
        return torch.empty_like(self, dtype=result_type)
    return self.new_empty(out_shape, dtype=result_type)
|
|
|
|
|
@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
    """Allocate the result of advanced indexing (``aten.index.Tensor``).

    Args:
        self: tensor being indexed.
        indices: non-empty sequence whose entries are index tensors or
            ``None`` (meaning "take the whole dim").

    Returns:
        ``self.new_empty(before + replacement + after)``: the sizes of the
        un-indexed dims before the index tensors, the broadcast shape of the
        index tensors, then the sizes of the remaining un-indexed dims.

    Raises:
        AssertionError: if ``indices`` is empty, an index has an unsupported
            dtype, a mask does not match the dims it covers, or there are
            more indices than dims.
    """
    assert indices, "at least one index must be provided"

    # checkIndexTensorTypes and expandTensors:
    # masks are replaced by one index tensor per mask dim, taken from the
    # columns of ``nonzero()``; everything else is kept as is.
    result: List[Optional[torch.Tensor]] = []
    for i, index in enumerate(indices):
        if index is None:
            result.append(index)
            continue
        assert index.dtype in [torch.long, torch.int8, torch.bool],\
            "tensors used as indices must be long, byte or bool tensors"
        if index.dtype not in [torch.int8, torch.bool]:
            result.append(index)
            continue
        nonzero = index.nonzero()
        # First dim of ``self`` covered by this mask.
        k = len(result)
        assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
        for j in range(index.ndim):
            assert index.shape[j] == self.shape[
                k +
                j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
            result.append(nonzero.select(1, j))
    indices = result
    assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"

    # expand_outplace
    import torch._refs as refs

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors so there is exactly one entry per dim
    indices.extend([None] * (self.ndim - len(indices)))

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    tensor_dims = [pos for pos, entry in enumerate(indices) if entry is not None]
    null_dims = [pos for pos, entry in enumerate(indices) if entry is None]
    has_contiguous_subspace = (not tensor_dims) or (tensor_dims[-1] - tensor_dims[0] + 1 == len(tensor_dims))

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        self = self.permute(tensor_dims + null_dims)
        indices = [indices[pos] for pos in tensor_dims] + [None] * len(null_dims)

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, entry in enumerate(indices):
        if entry is not None:
            replacement_shape = list(entry.shape)
        elif replacement_shape:
            after_shape.append(self.shape[dim])
        else:
            before_shape.append(self.shape[dim])
    return self.new_empty(before_shape + replacement_shape + after_shape)
|
|
|
|
|
# ============================== Embedding ========================================= |
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp |
|
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                  scale_grad_by_freq):
    """Allocate the weight gradient of a dense embedding lookup.

    The result has shape ``(num_weights, grad_output.size(-1))`` and takes its
    dtype, device and layout from ``grad_output``. ``indices``,
    ``padding_idx`` and ``scale_grad_by_freq`` are not inspected.
    """
    grad_weight_shape = (num_weights, grad_output.size(-1))
    return torch.empty(
        grad_weight_shape,
        dtype=grad_output.dtype,
        device=grad_output.device,
        layout=grad_output.layout,
    )
|
|
|
|
|
# ============================== Dropout =========================================== |
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp |
|
@register_meta(aten.native_dropout.default)
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
    """Allocate the ``(output, mask)`` pair of ``aten.native_dropout``.

    Both tensors are laid out like ``input``; the mask is boolean. ``p`` and
    ``train`` are not inspected.
    """
    return torch.empty_like(input), torch.empty_like(input, dtype=torch.bool)
|
|
|
|
|
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp |
|
@register_meta(aten.native_dropout_backward.default)
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
    """Return an uninitialised gradient laid out like ``grad``.

    ``mask`` and ``scale`` are not inspected.
    """
    grad_input = torch.empty_like(grad)
    return grad_input
|
|
|