mirror of https://github.com/hpcaitech/ColossalAI
[fx] refactored the file structure of patched function and module (#1238)
* [fx] refactored the file structure of patched function and module
* polish code

pull/1253/head

parent 17ed33350b
commit 7531c6271f
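Note (editor, not part of the commit): the meta_patched_function and meta_patched_module registries used throughout this diff come from the .registry module, which is not shown here. A minimal sketch of what such a registry could look like, assuming it is simply a mapping from the patched target to its shape-inference function (the actual implementation may differ):

# Hypothetical sketch of a patch registry; names follow this diff, internals are assumed.
class PatchRegistry:
    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, target):
        # used as a decorator: @meta_patched_function.register(torch.matmul)
        def wrapper(func):
            self.store[target] = func
            return func
        return wrapper

    def get(self, target):
        # a tracer would look up the patch when it encounters `target`
        return self.store.get(target)

meta_patched_function = PatchRegistry('meta_patched_function')
meta_patched_module = PatchRegistry('meta_patched_module')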
@@ -1,225 +0,0 @@
import operator

import torch

from .registry import meta_patched_function


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    # copied from huggingface.utils.fx
    def to_concrete(t):
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    return operator.getitem(a, b)


@meta_patched_function.register(torch.matmul)
def torch_matmul(input, other, *, out=None):
    # copied from huggingface.utils.fx
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")


@meta_patched_function.register(torch.arange)
def torch_arange(*args, **kwargs):
    n = len(args)
    step = 1
    if n == 1:
        start = 0
        end = args[0]
    elif n == 2:
        start, end = args
    else:
        start, end, step = args
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    return torch.empty((end - start) // step, dtype=dtype, device="meta")


@meta_patched_function.register(torch.where)
def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
    assert out is None, 'out is not supported yet'
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.Tensor.repeat)
def torch_tensor_repeat(self, *sizes):
    shape = list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.index_select)
def torch_index_select(input, dim, index, *, out=None):
    shape = list(input.shape)
    shape[dim] = len(index)
    return torch.empty(*shape, device="meta")


@meta_patched_function.register(torch.Tensor.index_select)
def torch_tensor_index_select(self, dim, index):
    return torch_index_select(self, dim, index)


@meta_patched_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(input,
                                  weight,
                                  padding_idx=None,
                                  max_norm=None,
                                  norm_type=2.0,
                                  scale_grad_by_freq=False,
                                  sparse=False):
    return torch.empty(*input.shape, weight.shape[-1], device="meta")


@meta_patched_function.register(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")


@meta_patched_function.register(torch.squeeze)
def torch_squeeze(input, dim=None):
    shape = list(input.shape)
    if dim is not None:
        if dim < 0:
            dim = input.dim() + dim
        if shape[dim] == 1:
            shape.pop(dim)
    else:
        new_shape = []
        for dim_value in shape:
            if dim_value == 1:
                continue
            new_shape.append(dim_value)
        shape = new_shape
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.squeeze)
def torch_tensor_squeeze(self, dim=None):
    return torch_squeeze(self, dim)


@meta_patched_function.register(torch.unsqueeze)
def torch_unsqueeze(input, dim):
    shape = list(input.shape)
    if dim < 0:
        dim = input.dim() + 1 + dim
    shape.insert(dim, 1)
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.unsqueeze)
def torch_tensor_unsqueeze(self, dim):
    return torch_unsqueeze(self, dim)


@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(input,
                            running_mean,
                            running_var,
                            weight=None,
                            bias=None,
                            training=False,
                            momentum=0.1,
                            eps=1e-05):
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
    assert out is None, 'saving to out is not supported yet'
    var = torch.empty(1).squeeze(0).to('meta')
    mean = torch.empty(1).squeeze(0).to('meta')
    return var, mean


@meta_patched_function.register(torch.cat)
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + dim
    shapes = [t.shape for t in tensors]
    shape = list(shapes[0])
    concatenated_dim = sum(s[dim] for s in shapes)
    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
    return torch.empty(final_shape, device="meta")


@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    return torch.empty(input.shape, device='meta')
@@ -0,0 +1,6 @@
from .activation_function import *
from .arithmetic import *
from .embedding import *
from .normalization import *
from .python_ops import *
from .torch_ops import *
@@ -0,0 +1,7 @@
import torch
from ..registry import meta_patched_function


@meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False):
    return torch.empty(input.shape, device='meta')
@@ -0,0 +1,63 @@
import torch
from ..registry import meta_patched_function


@meta_patched_function.register(torch.matmul)
def torch_matmul(input, other, *, out=None):
    # copied from huggingface.utils.fx
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d2 == 1:
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    if shape is None:
        return torch.tensor(0.0, device="meta")
    return torch.empty(*shape, device="meta")


@meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None):
    assert out is None, 'out is not supported yet'
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")


@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
    assert out is None, 'saving to out is not supported yet'
    var = torch.empty(1).squeeze(0).to('meta')
    mean = torch.empty(1).squeeze(0).to('meta')
    return var, mean
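Note (editor, not part of the commit): each patch above is a plain function, so its shape inference can be sanity-checked directly on meta tensors, e.g. torch_matmul with a batched left operand:

import torch

a = torch.empty(4, 8, 16, device='meta')   # batched matrix
b = torch.empty(16, 32, device='meta')     # plain matrix
out = torch_matmul(a, b)                   # torch_matmul as defined above
print(out.shape)                           # torch.Size([4, 8, 32]), no real compute performed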
@@ -0,0 +1,13 @@
import torch
from ..registry import meta_patched_function


@meta_patched_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(input,
                                  weight,
                                  padding_idx=None,
                                  max_norm=None,
                                  norm_type=2.0,
                                  scale_grad_by_freq=False,
                                  sparse=False):
    return torch.empty(*input.shape, weight.shape[-1], device="meta")
@@ -0,0 +1,19 @@
import torch
from ..registry import meta_patched_function


@meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    return torch.empty(input.shape, device='meta')


@meta_patched_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(input,
                            running_mean,
                            running_var,
                            weight=None,
                            bias=None,
                            training=False,
                            momentum=0.1,
                            eps=1e-05):
    return torch.empty(input.shape, device='meta')
@@ -0,0 +1,24 @@
import operator
import torch
from ..registry import meta_patched_function


@meta_patched_function.register(operator.getitem)
def operator_getitem(a, b):
    # copied from huggingface.utils.fx
    def to_concrete(t):
        if isinstance(t, torch.Tensor):
            concrete = torch.ones_like(t, device="cpu")
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        return t

    if isinstance(a, torch.Tensor):
        # TODO: infer shape without performing the computation.
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            b = to_concrete(b)
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    return operator.getitem(a, b)
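Note (editor, not part of the commit): the CPU round-trip above exists because advanced indexing cannot run on meta tensors; index tensors are concretized as CPU ones-tensors and the getitem is replayed on an empty CPU tensor before moving the result back to meta. For example:

import torch

x = torch.empty(10, 20, device='meta')
idx = torch.empty(5, dtype=torch.int64, device='meta')
out = operator_getitem(x, idx)   # patched getitem defined above
print(out.shape)                 # torch.Size([5, 20])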
@@ -0,0 +1,108 @@
import torch
from ..registry import meta_patched_function


@meta_patched_function.register(torch.arange)
def torch_arange(*args, **kwargs):
    n = len(args)
    step = 1
    if n == 1:
        start = 0
        end = args[0]
    elif n == 2:
        start, end = args
    else:
        start, end, step = args
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    step = kwargs.get("step", step)
    dtype = kwargs.get("dtype")
    return torch.empty((end - start) // step, dtype=dtype, device="meta")


@meta_patched_function.register(torch.where)
def torch_where(condition, x, y):
    # torch.where returns the broadcasted tensor of condition, x, and y,
    # so hack it by using addition
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


@meta_patched_function.register(torch.Tensor.repeat)
def torch_tensor_repeat(self, *sizes):
    shape = list(self.shape)
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.index_select)
def torch_index_select(input, dim, index, *, out=None):
    shape = list(input.shape)
    shape[dim] = len(index)
    return torch.empty(*shape, device="meta")


@meta_patched_function.register(torch.Tensor.index_select)
def torch_tensor_index_select(self, dim, index):
    return torch_index_select(self, dim, index)


@meta_patched_function.register(torch.squeeze)
def torch_squeeze(input, dim=None):
    shape = list(input.shape)
    if dim is not None:
        if dim < 0:
            dim = input.dim() + dim
        if shape[dim] == 1:
            shape.pop(dim)
    else:
        new_shape = []
        for dim_value in shape:
            if dim_value == 1:
                continue
            new_shape.append(dim_value)
        shape = new_shape
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.squeeze)
def torch_tensor_squeeze(self, dim=None):
    return torch_squeeze(self, dim)


@meta_patched_function.register(torch.unsqueeze)
def torch_unsqueeze(input, dim):
    shape = list(input.shape)
    if dim < 0:
        dim = input.dim() + 1 + dim
    shape.insert(dim, 1)
    return torch.empty(shape, device="meta")


@meta_patched_function.register(torch.Tensor.unsqueeze)
def torch_tensor_unsqueeze(self, dim):
    return torch_unsqueeze(self, dim)


@meta_patched_function.register(torch.cat)
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    if dim is None and axis is None:
        dim = 0
    if dim is None and axis is not None:
        dim = axis
    if dim < 0:
        dim = tensors[0].dim() + dim
    shapes = [t.shape for t in tensors]
    shape = list(shapes[0])
    concatenated_dim = sum(s[dim] for s in shapes)
    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:]
    return torch.empty(final_shape, device="meta")


@meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None):
    return torch.empty(input.shape, device='meta')
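Note (editor, not part of the commit): torch_cat only sums the sizes along the concatenation dimension, so it can be checked without allocating on a real device:

import torch

t1 = torch.empty(2, 3, device='meta')
t2 = torch.empty(4, 3, device='meta')
out = torch_cat([t1, t2], dim=0)   # torch_cat as defined above
print(out.shape)                   # torch.Size([6, 3])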
@@ -0,0 +1,6 @@
from .activation_function import *
from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
@@ -0,0 +1,11 @@
import torch
from ..registry import meta_patched_module


@meta_patched_module.register(torch.nn.ReLU)
@meta_patched_module.register(torch.nn.Sigmoid)
@meta_patched_module.register(torch.nn.GELU)
@meta_patched_module.register(torch.nn.Tanh)
@meta_patched_module.register(torch.nn.ReLU6)
def torch_nn_non_linear_act(self, input):
    return torch.empty(input.shape, device='meta')
@@ -0,0 +1,57 @@
import math
import torch
from ..registry import meta_patched_module


@meta_patched_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
                        (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
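Note (editor, not part of the commit): the conv patches apply the standard output-size formula, floor((size_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1) per spatial dimension. A quick check against torch.nn.Conv2d:

import torch

conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
x = torch.empty(1, 3, 32, 32, device='meta')
out = torch_nn_conv2d(conv, x)   # patched forward defined above
print(out.shape)                 # torch.Size([1, 16, 16, 16])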
@@ -0,0 +1,8 @@
import torch
from ..registry import meta_patched_module


@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
    result_shape = input.shape + (self.embedding_dim,)
    return torch.empty(result_shape, device='meta')
@@ -0,0 +1,9 @@
import torch
from ..registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
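Note (editor, not part of the commit): the module patches take the module instance as self, so they can be invoked as free functions for a quick check:

import torch

fc = torch.nn.Linear(128, 64)
x = torch.empty(32, 128, device='meta')
out = torch_nn_linear(fc, x)   # patched forward defined above
print(out.shape)               # torch.Size([32, 64])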
@@ -0,0 +1,20 @@
import torch
from ..registry import meta_patched_module


@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    # check shape
    if isinstance(self, torch.nn.BatchNorm1d):
        assert input.dim() in [2, 3]
    elif isinstance(self, torch.nn.BatchNorm2d):
        assert input.dim() == 4
    elif isinstance(self, torch.nn.BatchNorm3d):
        assert input.dim() == 5

    # normalization maintains the same shape as the input
    return input.clone()
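Note (editor, not part of the commit): the normalization patch returns input.clone(), so the inferred output shape always equals the input shape once the dimension asserts pass:

import torch

bn = torch.nn.BatchNorm2d(8)
x = torch.empty(4, 8, 14, 14, device='meta')
out = torch_nn_normalize(bn, x)   # patched forward defined above
print(out.shape)                  # torch.Size([4, 8, 14, 14])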
@@ -1,91 +1,6 @@
import math
import torch
from .registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
    last_dim = input.shape[-1]
    assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch'
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")


@meta_patched_module.register(torch.nn.LayerNorm)
@meta_patched_module.register(torch.nn.GroupNorm)
@meta_patched_module.register(torch.nn.BatchNorm1d)
@meta_patched_module.register(torch.nn.BatchNorm2d)
@meta_patched_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self, input):
    # check shape
    if isinstance(self, torch.nn.BatchNorm1d):
        assert input.dim() in [2, 3]
    elif isinstance(self, torch.nn.BatchNorm2d):
        assert input.dim() == 4
    elif isinstance(self, torch.nn.BatchNorm3d):
        assert input.dim() == 5

    # normalization maintains the same shape as the input
    return input.clone()


@meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input):
    result_shape = input.shape + (self.embedding_dim,)
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    l_in = input.shape[-1]
    c_out = self.out_channels
    l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    result_shape = input.shape[:-2] + (
        c_out,
        l_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    h_in, w_in = input.shape[-2:]
    c_out = self.out_channels
    h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self, input):
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
    d_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
                        (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
    h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
                        (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
                        (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    return torch.empty(result_shape, device='meta')
from ..registry import meta_patched_module


@meta_patched_module.register(torch.nn.AvgPool1d)
@@ -274,10 +189,4 @@ def torch_nn_adapative_pooling_3d(self, input):
        self.output_size,
        self.output_size,
    )
    return torch.empty(result_shape, device='meta')


@meta_patched_module.register(torch.nn.ReLU)
@meta_patched_module.register(torch.nn.ReLU6)
def torch_nn_func_relu(self, input):
    return torch.empty(input.shape, device='meta')