@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
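For orientation: tensors on the meta device carry only shape, dtype, and layout metadata and no storage, so every kernel registered through this "Meta" library must compute output metadata without doing any real work. A minimal sketch using only stock PyTorch:

    x = torch.empty(2, 3, device="meta")  # no storage is allocated
    y = x + x                             # dispatches to a Meta kernel: metadata only
    assert y.shape == (2, 3) and y.device.type == "meta"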
@@ -14,16 +13,17 @@ meta_table = {}

def register_meta(op, register_dispatcher=True):
    def wrapper(f):
        def add_func(op):
            meta_table[op] = f
            if register_dispatcher:
-                name = (
-                    op.__name__
-                    if op._overloadname != "default"
-                    else op.overloadpacket.__name__
-                )
-                meta_lib.impl(name, f)
+                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                try:
+                    meta_lib.impl(name, f)
+                except:
+                    pass

        tree_map(add_func, op)
        return f
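A hedged sketch of how the decorator is used; the op here (aten.relu.default) is purely illustrative, and the tree_map call means a list of overloads works just as well as a single one:

    @register_meta(aten.relu.default)
    def meta_relu(self):
        return torch.empty_like(self)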
@@ -44,6 +44,7 @@ def meta_conv(
    output_padding: List[int],
    groups: int,
):

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
@@ -120,14 +121,9 @@ def meta_conv(
                    kernel_size[i],
                    stride[i],
                    output_padding_list[i],
-                )
-            )
+                ))
        else:
-            ret_shape.append(
-                _formula(
-                    dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
-                )
-            )
+            ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
    return ret_shape
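The body of _formula is cut off by this hunk; assuming it implements the standard convolution length arithmetic, (ln + 2 * p - d * (k - 1) - 1) // s + 1, here is a self-contained check of the values it should produce:

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    assert _formula(32, 1, 1, 3, 1) == 32  # 3x3 "same" conv keeps the length
    assert _formula(32, 0, 1, 3, 2) == 15  # strided conv roughly halves it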

    def pick_memory_format():
@@ -156,20 +152,16 @@ def meta_conv(
    out_channels = weight.shape[0]
    if weight.shape[1] != input_tensor.shape[1] / groups:
        raise RuntimeError("Invalid channel dimensions")
-    shape_out = calc_conv_nd_return_shape(
-        dims, kernel_size, stride, padding, dilation
-    )
+    shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)

    out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
    mem_fmt = pick_memory_format()
    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
    return out
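With this registration active, convolution shapes can be inferred with no compute at all; a hedged usage sketch:

    x = torch.empty(8, 3, 32, 32, device="meta")
    w = torch.empty(16, 3, 3, 3, device="meta")
    out = torch.nn.functional.conv2d(x, w, padding=1)
    assert out.shape == (8, 16, 32, 32)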

@register_meta(aten.convolution_backward.default)
-def meta_conv_backward(
-    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
-):
+def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+                       padding, dilation, transposed, output_padding, groups, output_mask):
    return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
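Each gradient mirrors the shape of the tensor it differentiates; a hedged check that calls the meta kernel directly, so no dispatcher setup is needed:

    g = torch.empty(8, 16, 32, 32, device="meta")  # grad_output
    x = torch.empty(8, 3, 32, 32, device="meta")   # input
    w = torch.empty(16, 3, 3, 3, device="meta")    # weight
    gi, gw, gb = meta_conv_backward(g, x, w, [16], [1, 1], [1, 1], [1, 1],
                                    False, [0, 0], 1, [True, True, True])
    assert gi.shape == x.shape and gw.shape == w.shape and gb.shape == (16,)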

@@ -184,21 +176,18 @@ def meta_hardswish(input: torch.Tensor):

@register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
    grad_in = torch.empty_like(input)
    return grad_in

-@register_meta([aten.roll.default,])
+@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
    return torch.empty_like(input)
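roll only permutes elements, so empty_like is the whole meta kernel; e.g., assuming the registration above is active:

    t = torch.empty(4, 5, device="meta")
    assert torch.roll(t, shifts=1, dims=0).shape == t.shape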

@register_meta(aten.native_batch_norm.default)
-def meta_bn(
-    input: torch.Tensor,
-    weight, bias, running_mean, running_var, training, momentum, eps
-):
+def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
    n_input = input.size(1)

    output = torch.empty_like(input)

@@ -208,10 +197,8 @@ def meta_bn(

@register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(
-    dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
-):
+def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
+                     save_invstd, train, eps, output_mask):
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(weight)
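The return statement falls outside this hunk, but assuming the kernel returns (dX, dgamma, dbeta), a hedged shape check by direct call:

    dY = torch.empty(8, 16, 32, 32, device="meta")
    x = torch.empty_like(dY)
    gamma = torch.empty(16, device="meta")
    dX, dgamma, dbeta = meta_bn_backward(dY, x, gamma, None, None, None, None,
                                         True, 1e-5, [True, True, True])
    assert dX.shape == x.shape and dgamma.shape == dbeta.shape == (16,)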

@@ -219,10 +206,7 @@ def meta_bn_backward(

@register_meta(aten.native_layer_norm.default)
-def meta_ln(
-    input: torch.Tensor,
-    normalized_shape, weight, bias, eps
-):
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
    n_input = input.size(1)

    output = torch.empty_like(input)

@@ -232,11 +216,8 @@ def meta_ln(

@register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(
-    dY: torch.Tensor,
-    input: torch.Tensor,
-    normalized_shape, mean, rstd, weight, bias, grad_input_mask
-):
+def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+                     grad_input_mask):
    dX = torch.empty_like(input)
    dgamma = torch.empty_like(weight)
    dbeta = torch.empty_like(bias)
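Unlike batch norm above, dbeta here mirrors bias rather than weight; both carry normalized_shape. The return is again cut off, so assuming (dX, dgamma, dbeta) comes back:

    dY = torch.empty(8, 128, device="meta")
    w = torch.empty(128, device="meta")
    dX, dgamma, dbeta = meta_ln_backward(dY, torch.empty_like(dY), [128], None,
                                         None, w, w, [True, True, True])
    assert dX.shape == dY.shape and dgamma.shape == dbeta.shape == (128,)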

@@ -245,7 +226,8 @@ def meta_ln_backward(

@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor, input: torch.Tensor,
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
):
    grad_input = torch.empty_like(input)
    return grad_input
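The gradient of adaptive average pooling always has the input's shape, whatever size the output was pooled to; a hedged direct-call check:

    go = torch.empty(8, 16, 1, 1, device="meta")  # grad w.r.t. a 1x1 pooled output
    x = torch.empty(8, 16, 32, 32, device="meta")
    assert meta_adaptive_avg_pool2d_backward(go, x).shape == x.shape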

@@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices):
                k = len(result)
                assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                for j in range(index.ndim):
-                    assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                    assert index.shape[j] == self.shape[
+                        k +
+                        j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)

@@ -275,7 +259,7 @@ def meta_index_Tensor(self, indices):
        indices = result
    assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"

    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy
    indices = list(refs._maybe_broadcast(*indices))

    # add missing null tensors
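After broadcasting, the result of advanced indexing is the broadcast index shape followed by the unindexed dimensions; a hedged sketch, assuming this registration is active for aten.index.Tensor:

    t = torch.empty(4, 5, 6, device="meta")
    idx = torch.empty(2, 3, dtype=torch.long, device="meta")
    assert t[idx].shape == (2, 3, 5, 6)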