mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] avoid conflict of meta registry with torch 1.13.0. (#1530)
* [hotfix] avoid conflict of meta registry with torch 1.13.0.

* [hotfix] avoid conflict of meta registry with torch 1.13.0.

branch: pull/1533/head
parent: b231430bcb
commit: 112a1f0a8f
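Context for the hunks below: `meta_lib.impl(name, f)` registers a Meta-dispatch kernel for an aten op through `torch.library.Library("aten", "IMPL", "Meta")`, and that call fails when a Meta kernel for the op is already registered, which is presumably what started happening once torch 1.13.0 began shipping its own meta registrations. The first hunks wrap the call in try/except; the remaining hunks are formatting-only (long signatures and call chains collapsed onto single lines). A minimal sketch of the guarded-registration pattern, with an illustrative op (`aten::roll`) and function name that are not taken from the commit:

```python
import torch

# Meta kernels compute only output shapes/dtypes; no real data is touched.
meta_lib = torch.library.Library("aten", "IMPL", "Meta")


def my_meta_roll(input, shifts, dims=()):
    return torch.empty_like(input)


try:
    # On torch builds that already register a Meta kernel for aten::roll
    # (e.g. 1.13.0), this re-registration raises; the commit itself uses a
    # bare `except: pass` and simply keeps the kernel torch already provides.
    meta_lib.impl("roll", my_meta_roll)
except RuntimeError:
    pass
```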
@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.utils._pytree import tree_map

 aten = torch.ops.aten

 meta_lib = torch.library.Library("aten", "IMPL", "Meta")
@@ -14,16 +13,17 @@ meta_table = {}
 def register_meta(op, register_dispatcher=True):

     def wrapper(f):

         def add_func(op):
             meta_table[op] = f
             if register_dispatcher:
-                name = (
-                    op.__name__
-                    if op._overloadname != "default"
-                    else op.overloadpacket.__name__
-                )
-                meta_lib.impl(name, f)
+                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                try:
+                    meta_lib.impl(name, f)
+                except:
+                    pass

         tree_map(add_func, op)
         return f
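For orientation on the decorator itself: `register_meta` stores `f` in `meta_table` and, because it walks `op` with `tree_map`, accepts either a single overload or a (nested) list of overloads; that is why a later hunk can turn `@register_meta([aten.roll.default, ])` into `@register_meta(aten.roll.default)` with no change in behaviour. A short usage sketch, assuming the `torch`, `aten` and `register_meta` names already defined in this file:

```python
# Single-overload form (the style the cleanup below settles on).
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
    # roll never changes the shape, so an empty like-tensor is a faithful meta result.
    return torch.empty_like(input)


# The list form is equivalent, because tree_map calls add_func on every leaf:
#
#     @register_meta([aten.roll.default])
#     def meta_roll(...): ...
```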
@@ -44,6 +44,7 @@ def meta_conv(
     output_padding: List[int],
     groups: int,
 ):
+
     def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
         """
         Formula to apply to calculate the length of some dimension of the output
@@ -120,14 +121,9 @@ def meta_conv(
                         kernel_size[i],
                         stride[i],
                         output_padding_list[i],
-                    )
-                )
+                    ))
             else:
-                ret_shape.append(
-                    _formula(
-                        dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
-                    )
-                )
+                ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
         return ret_shape

     def pick_memory_format():
@@ -156,9 +152,7 @@ def meta_conv(
     out_channels = weight.shape[0]
     if weight.shape[1] != input_tensor.shape[1] / groups:
         raise RuntimeError("Invalid channel dimensions")
-    shape_out = calc_conv_nd_return_shape(
-        dims, kernel_size, stride, padding, dilation
-    )
+    shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
     out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
     mem_fmt = pick_memory_format()
     out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
@@ -166,10 +160,8 @@ def meta_conv(

 @register_meta(aten.convolution_backward.default)
-def meta_conv_backward(
-    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
-):
+def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+                       padding, dilation, transposed, output_padding, groups, output_mask):
     return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')

@@ -189,16 +181,13 @@ def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
     return grad_in


-@register_meta([aten.roll.default, ])
+@register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
     return torch.empty_like(input)


 @register_meta(aten.native_batch_norm.default)
-def meta_bn(
-    input: torch.Tensor,
-    weight, bias, running_mean, running_var, training, momentum, eps
-):
+def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)

     output = torch.empty_like(input)
@@ -208,10 +197,8 @@ def meta_bn(


 @register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(
-    dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
-):
+def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
+                     save_invstd, train, eps, output_mask):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(weight)
@@ -219,10 +206,7 @@ def meta_bn_backward(


 @register_meta(aten.native_layer_norm.default)
-def meta_ln(
-    input: torch.Tensor,
-    normalized_shape, weight, bias, eps
-):
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     n_input = input.size(1)

     output = torch.empty_like(input)
@@ -232,11 +216,8 @@ def meta_ln(


 @register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(
-    dY: torch.Tensor,
-    input: torch.Tensor,
-    normalized_shape, mean, rstd, weight, bias, grad_input_mask
-):
+def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+                     grad_input_mask):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(bias)
@@ -245,7 +226,8 @@ def meta_ln_backward(

 @register_meta(aten._adaptive_avg_pool2d_backward.default)
 def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor, input: torch.Tensor,
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
 ):
     grad_input = torch.empty_like(input)
     return torch.empty_like(input)
@@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices):
                 k = len(result)
                 assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                 for j in range(index.ndim):
-                    assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                    assert index.shape[j] == self.shape[
+                        k +
+                        j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                     result.append(nonzero.select(1, j))
             else:
                 result.append(index)
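Stepping back, these kernels exist so that shape propagation can run on the `meta` device without allocating real memory; the hotfix only makes their registration tolerate torch builds that already provide them. A small illustration of the end result in plain PyTorch, independent of this commit (recent torch releases ship the convolution meta kernel themselves):

```python
import torch

# "meta" tensors carry only shape/dtype metadata, no storage.
x = torch.empty(8, 3, 224, 224, device="meta")
w = torch.empty(64, 3, 7, 7, device="meta")

# With a Meta kernel registered for convolution, this returns a meta tensor of
# the correct output shape without doing any real compute.
y = torch.nn.functional.conv2d(x, w, stride=2, padding=3)
print(y.shape)  # torch.Size([8, 64, 112, 112])
```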