
[hotfix] avoid conflict of meta registry with torch 1.13.0. (#1530)

* [hotfix] avoid conflict of meta registry with torch 1.13.0.

* [hotfix] avoid conflict of meta registry with torch 1.13.0.
Commit 112a1f0a8f by Super Daniel, committed via GitHub.
1 changed file: colossalai/fx/profiler/_meta_registrations.py (74 changed lines)
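The commit message above is terse, so here is a hedged sketch of the underlying problem and of the guard the diff below introduces. It is illustrative only and not part of the commit: the assumption is that torch 1.13.0 already registers Meta-dispatch kernels for many `aten` ops, so a second `Library.impl` call for the same op raises an error, while ColossalAI still wants the function recorded in its own `meta_table`.

import torch

# Illustrative guard (not the commit's code verbatim): tolerate ops whose meta
# kernels torch already owns, as on torch 1.13.0.
meta_lib = torch.library.Library("aten", "IMPL", "Meta")

def try_impl(name, fn):
    try:
        meta_lib.impl(name, fn)
    except Exception:    # the patch itself uses a bare `except:` with `pass`
        pass             # torch already provides this meta kernel; skip registration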

@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.utils._pytree import tree_map
 aten = torch.ops.aten
 meta_lib = torch.library.Library("aten", "IMPL", "Meta")
@@ -14,16 +13,17 @@ meta_table = {}
 def register_meta(op, register_dispatcher=True):
     def wrapper(f):
         def add_func(op):
             meta_table[op] = f
             if register_dispatcher:
-                name = (
-                    op.__name__
-                    if op._overloadname != "default"
-                    else op.overloadpacket.__name__
-                )
-                meta_lib.impl(name, f)
+                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                try:
+                    meta_lib.impl(name, f)
+                except:
+                    pass
         tree_map(add_func, op)
     return f
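As a hedged usage sketch of the decorator after this change (the op below is hypothetical and chosen only for illustration; it is not one of the registrations added by this commit), the meta function is still stored in `meta_table` even when the dispatcher registration is skipped:

@register_meta(aten.relu.default)    # hypothetical example op
def meta_relu(input: torch.Tensor):
    # shape/dtype propagation only; nothing is computed on meta tensors
    return torch.empty_like(input)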
@@ -44,6 +44,7 @@ def meta_conv(
     output_padding: List[int],
     groups: int,
 ):
     def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
         """
         Formula to apply to calculate the length of some dimension of the output
@@ -120,14 +121,9 @@ def meta_conv(
                         kernel_size[i],
                         stride[i],
                         output_padding_list[i],
-                    )
-                )
+                    ))
             else:
-                ret_shape.append(
-                    _formula(
-                        dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
-                    )
-                )
+                ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
         return ret_shape

     def pick_memory_format():
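The collapsed call above forwards each dimension to the `_formula` helper whose docstring appears, truncated, in the previous hunk. As a hedged reminder of what that arithmetic works out to (assuming the standard PyTorch convolution output-size formula, which is what the surrounding parameters suggest):

def conv_out_len(ln: int, p: int, d: int, k: int, s: int) -> int:
    # assumed formula: floor((ln + 2*p - d*(k - 1) - 1) / s) + 1
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1

# e.g. a length-224 dimension with kernel 3, stride 2, padding 1, dilation 1 -> 112
assert conv_out_len(224, p=1, d=1, k=3, s=2) == 112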
@@ -156,20 +152,16 @@ def meta_conv(
     out_channels = weight.shape[0]
     if weight.shape[1] != input_tensor.shape[1] / groups:
         raise RuntimeError("Invalid channel dimensions")
-    shape_out = calc_conv_nd_return_shape(
-        dims, kernel_size, stride, padding, dilation
-    )
+    shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
     out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
     mem_fmt = pick_memory_format()
-    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
+    out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
     return out

 @register_meta(aten.convolution_backward.default)
-def meta_conv_backward(
-    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
-):
+def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+                       padding, dilation, transposed, output_padding, groups, output_mask):
     return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
@@ -184,21 +176,18 @@ def meta_hardswish(input: torch.Tensor):
 @register_meta(aten.hardswish_backward.default)
-def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
+def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
     grad_in = torch.empty_like(input)
     return grad_in

-@register_meta([aten.roll.default, ])
-def meta_roll(input:torch.Tensor, shifts, dims):
+@register_meta(aten.roll.default)
+def meta_roll(input: torch.Tensor, shifts, dims):
     return torch.empty_like(input)

 @register_meta(aten.native_batch_norm.default)
-def meta_bn(
-    input: torch.Tensor,
-    weight, bias, running_mean, running_var, training, momentum, eps
-):
+def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)
     output = torch.empty_like(input)
@@ -208,10 +197,8 @@ def meta_bn(
 @register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(
-    dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
-):
+def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
+                     save_invstd, train, eps, output_mask):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(weight)
@@ -219,10 +206,7 @@ def meta_bn_backward(
 @register_meta(aten.native_layer_norm.default)
-def meta_ln(
-    input: torch.Tensor,
-    normalized_shape, weight, bias, eps
-):
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     n_input = input.size(1)
     output = torch.empty_like(input)
@@ -232,11 +216,8 @@ def meta_ln(
 @register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(
-    dY: torch.Tensor,
-    input: torch.Tensor,
-    normalized_shape, mean, rstd, weight, bias, grad_input_mask
-):
+def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+                     grad_input_mask):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(bias)
@@ -245,7 +226,8 @@ def meta_ln_backward(
 @register_meta(aten._adaptive_avg_pool2d_backward.default)
 def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor, input: torch.Tensor,
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
 ):
     grad_input = torch.empty_like(input)
     return torch.empty_like(input)
@@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices):
                 k = len(result)
                 assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                 for j in range(index.ndim):
-                    assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                    assert index.shape[j] == self.shape[
+                        k +
+                        j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                     result.append(nonzero.select(1, j))
             else:
                 result.append(index)
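The reformatted assert above guards the boolean-mask indexing path, where a mask is expanded into per-dimension integer indices via `nonzero`. A small hedged illustration in plain torch (not part of the commit) of the shape relationship being checked:

import torch

t = torch.rand(4, 5)
mask = t > 0.5                                  # boolean mask covering the leading dimensions
idx = mask.nonzero()                            # (n, 2): one integer column per mask dimension
assert mask.shape == t.shape[:mask.ndim]        # the relationship the assert enforces
print(t[mask].shape)                            # torch.Size([n])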
@@ -275,7 +259,7 @@ def meta_index_Tensor(self, indices):
         indices = result
     assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
     # expand_outplace
-    import torch._refs as refs  # avoid import cycle in mypy
+    import torch._refs as refs    # avoid import cycle in mypy
     indices = list(refs._maybe_broadcast(*indices))
     # add missing null tensors
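Taken together, these meta implementations let the fx profiler propagate shapes on the 'meta' device without allocating real memory. A hedged end-to-end sketch in plain torch (assuming a setup where convolution has a meta kernel, either natively on torch 1.13.0 or via the registrations above on older versions):

import torch
import torch.nn.functional as F

x = torch.empty(8, 3, 224, 224, device='meta')
w = torch.empty(16, 3, 3, 3, device='meta')
out = F.conv2d(x, w, stride=2, padding=1)    # shape inference only, no real data
print(out.shape)    # torch.Size([8, 16, 112, 112])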
