From 7531c6271f143304452fb6af35d944d7be370bec Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 12 Jul 2022 15:01:58 +0800 Subject: [PATCH] [fx] refactored the file structure of patched function and module (#1238) * [fx] refactored the file structure of patched function and module * polish code --- .../fx/tracer/meta_patch/patched_function.py | 225 ------------------ .../meta_patch/patched_function/__init__.py | 6 + .../patched_function/activation_function.py | 7 + .../meta_patch/patched_function/arithmetic.py | 63 +++++ .../meta_patch/patched_function/embedding.py | 13 + .../patched_function/normalization.py | 19 ++ .../meta_patch/patched_function/python_ops.py | 24 ++ .../meta_patch/patched_function/torch_ops.py | 108 +++++++++ .../meta_patch/patched_module/__init__.py | 6 + .../patched_module/activation_function.py | 11 + .../meta_patch/patched_module/convolution.py | 57 +++++ .../meta_patch/patched_module/embedding.py | 8 + .../meta_patch/patched_module/linear.py | 9 + .../patched_module/normalization.py | 20 ++ .../pooling.py} | 95 +------- 15 files changed, 353 insertions(+), 318 deletions(-) delete mode 100644 colossalai/fx/tracer/meta_patch/patched_function.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/__init__.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/activation_function.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/embedding.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/normalization.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/python_ops.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/__init__.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/activation_function.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/convolution.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/embedding.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/linear.py create mode 100644 colossalai/fx/tracer/meta_patch/patched_module/normalization.py rename colossalai/fx/tracer/meta_patch/{patched_module.py => patched_module/pooling.py} (63%) diff --git a/colossalai/fx/tracer/meta_patch/patched_function.py b/colossalai/fx/tracer/meta_patch/patched_function.py deleted file mode 100644 index dd6312ccb..000000000 --- a/colossalai/fx/tracer/meta_patch/patched_function.py +++ /dev/null @@ -1,225 +0,0 @@ -from curses import meta -import operator -import torch -from .registry import meta_patched_function - - -@meta_patched_function.register(operator.getitem) -def operator_getitem(a, b): - # copied from huggingface.utils.fx - def to_concrete(t): - if isinstance(t, torch.Tensor): - concrete = torch.ones_like(t, device="cpu") - if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]: - concrete = concrete.to(torch.int64) - return concrete - return t - - if isinstance(a, torch.Tensor): - # TODO: infer shape without performing the computation. 
- if isinstance(b, tuple): - b = tuple(map(to_concrete, b)) - else: - b = to_concrete(b) - return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") - return operator.getitem(a, b) - - -@meta_patched_function.register(torch.matmul) -def torch_matmul(input, other, *, out=None): - # copied from huggingface.utils.fx - d1 = input.dim() - d2 = other.dim() - shape = None - if d1 == 1 and d2 == 1: - shape = None - elif d1 == 2 and d2 == 2: - shape = (input.size(0), other.size(1)) - elif d1 == 1 and d2 == 2: - shape = (other.size(1),) - elif d1 == 2 and d1 == 1: - shape = (input.size(0),) - else: - max_length = max(input.dim(), other.dim()) - shape1 = list(input.shape) - shape2 = list(other.shape) - if d1 == 1: - shape1 = [1] + shape1 - if d2 == 1: - shape2.append(1) - shape1 = [-1] * (max_length - d1) + list(input.shape) - shape2 = [-1] * (max_length - d2) + list(other.shape) - shape = [] - for i in range(max_length): - shape.append(max(shape1[i], shape2[i])) - shape[-2] = shape1[-2] - shape[-1] = shape2[-1] - if d1 == 1: - shape.pop(-2) - if d2 == 1: - shape.pop(-1) - if shape is None: - return torch.tensor(0.0, device="meta") - return torch.empty(*shape, device="meta") - - -@meta_patched_function.register(torch.arange) -def torch_arange(*args, **kwargs): - n = len(args) - step = 1 - if n == 1: - start = 0 - end = args[0] - elif n == 2: - start, end = args - else: - start, end, step = args - if isinstance(start, float): - start = int(start) - if isinstance(end, float): - start = int(end) - if isinstance(step, float): - step = int(step) - step = kwargs.get("step", step) - dtype = kwargs.get("dtype") - return torch.empty((end - start) // step, dtype=dtype, device="meta") - - -@meta_patched_function.register(torch.where) -def torch_where(condition, x, y): - # torch.where returns the broadcasted tensor of condition, x, and y, - # so hack it by using addition - return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") - - -@meta_patched_function.register(torch.abs) -def torch_abs(input, *, out=None): - assert out is None, 'out is not supported yet' - return torch.empty(input.shape, device='meta') - - -@meta_patched_function.register(torch.nn.functional.relu) -def torch_nn_func_relu(input, inplace=False): - return torch.empty(input.shape, device='meta') - - -@meta_patched_function.register(torch.Tensor.repeat) -def torch_tensor_repeat(self, *sizes): - shape = list(self.shape) - for i, x in enumerate(sizes): - shape[i] *= x - return torch.empty(shape, device="meta") - - -@meta_patched_function.register(torch.index_select) -def torch_index_select(input, dim, index, *, out=None): - shape = list(input.shape) - shape[dim] = len(index) - return torch.empty(*shape, device="meta") - - -@meta_patched_function.register(torch.Tensor.index_select) -def torch_tensor_index_select(self, dim, index): - return torch_index_select(self, dim, index) - - -@meta_patched_function.register(torch.nn.functional.embedding) -def torch_nn_functional_embedding(input, - weight, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False): - return torch.empty(*input.shape, weight.shape[-1], device="meta") - - -@meta_patched_function.register(torch.bmm) -def torch_bmm(input, mat2, *, out=None): - if out is not None: - raise ValueError("Don't support in-place abs for MetaTensor analysis") - batch_size, n, m = input.shape - _, _, p = mat2.shape - return torch.empty(batch_size, n, p, device="meta") - - -@meta_patched_function.register(torch.squeeze) -def 
torch_squeeze(input, dim=None): - shape = list(input.shape) - if dim is not None: - if dim < 0: - dim = input.dim() + dim - if shape[dim] == 1: - shape.pop(dim) - else: - new_shape = [] - for dim_value in shape: - if dim_value == 1: - continue - new_shape.append(dim_value) - shape = new_shape - return torch.empty(shape, device="meta") - - -@meta_patched_function.register(torch.Tensor.squeeze) -def torch_tensor_squeeze(self, dim=None): - return torch_squeeze(self, dim) - - -@meta_patched_function.register(torch.unsqueeze) -def torch_unsqueeze(input, dim): - shape = list(input.shape) - if dim < 0: - dim = input.dim() + 1 + dim - shape.insert(dim, 1) - return torch.empty(shape, device="meta") - - -@meta_patched_function.register(torch.Tensor.unsqueeze) -def torch_tensor_unsqueeze(self, dim): - return torch_unsqueeze(self, dim) - - -@meta_patched_function.register(torch.nn.functional.layer_norm) -def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): - return torch.empty(input.shape, device='meta') - - -@meta_patched_function.register(torch.nn.functional.batch_norm) -def torch_nn_func_batchnorm(input, - running_mean, - running_var, - weight=None, - bias=None, - training=False, - momentum=0.1, - eps=1e-05): - return torch.empty(input.shape, device='meta') - - -@meta_patched_function.register(torch.var_mean) -def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): - assert out is None, 'saving to out is not supported yet' - var = torch.empty(1).squeeze(0).to('meta') - mean = torch.empty(1).squeeze(0).to('meta') - return var, mean - - -@meta_patched_function.register(torch.cat) -def torch_cat(tensors, dim=None, axis=None, *, out=None): - if dim is None and axis is None: - dim = 0 - if dim is None and axis is not None: - dim = axis - if dim < 0: - dim = tensors[0].dim() + dim - shapes = [t.shape for t in tensors] - shape = list(shapes[0]) - concatenated_dim = sum(shape[dim] for shape in shapes) - final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] - return torch.empty(final_shape, device="meta") - - -@meta_patched_function.register(torch.roll) -def torch_roll(input, shifts, dims=None): - return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_function/__init__.py b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py new file mode 100644 index 000000000..ca20ac0a9 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/__init__.py @@ -0,0 +1,6 @@ +from .activation_function import * +from .arithmetic import * +from .embedding import * +from .normalization import * +from .python_ops import * +from .torch_ops import * \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py new file mode 100644 index 000000000..d710098c7 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -0,0 +1,7 @@ +import torch +from ..registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.relu) +def torch_nn_func_relu(input, inplace=False): + return torch.empty(input.shape, device='meta') \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py new file mode 100644 index 000000000..3077262db --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -0,0 
+1,63 @@ +import torch +from ..registry import meta_patched_function + + +@meta_patched_function.register(torch.matmul) +def torch_matmul(input, other, *, out=None): + # copied from huggingface.utils.fx + d1 = input.dim() + d2 = other.dim() + shape = None + if d1 == 1 and d2 == 1: + shape = None + elif d1 == 2 and d2 == 2: + shape = (input.size(0), other.size(1)) + elif d1 == 1 and d2 == 2: + shape = (other.size(1),) + elif d1 == 2 and d1 == 1: + shape = (input.size(0),) + else: + max_length = max(input.dim(), other.dim()) + shape1 = list(input.shape) + shape2 = list(other.shape) + if d1 == 1: + shape1 = [1] + shape1 + if d2 == 1: + shape2.append(1) + shape1 = [-1] * (max_length - d1) + list(input.shape) + shape2 = [-1] * (max_length - d2) + list(other.shape) + shape = [] + for i in range(max_length): + shape.append(max(shape1[i], shape2[i])) + shape[-2] = shape1[-2] + shape[-1] = shape2[-1] + if d1 == 1: + shape.pop(-2) + if d2 == 1: + shape.pop(-1) + if shape is None: + return torch.tensor(0.0, device="meta") + return torch.empty(*shape, device="meta") + + +@meta_patched_function.register(torch.abs) +def torch_abs(input, *, out=None): + assert out is None, 'out is not supported yet' + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if out is not None: + raise ValueError("Don't support in-place abs for MetaTensor analysis") + batch_size, n, m = input.shape + _, _, p = mat2.shape + return torch.empty(batch_size, n, p, device="meta") + + +@meta_patched_function.register(torch.var_mean) +def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): + assert out is None, 'saving to out is not supported yet' + var = torch.empty(1).squeeze(0).to('meta') + mean = torch.empty(1).squeeze(0).to('meta') + return var, mean diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py new file mode 100644 index 000000000..42fb359b5 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -0,0 +1,13 @@ +import torch +from ..registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.embedding) +def torch_nn_functional_embedding(input, + weight, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False): + return torch.empty(*input.shape, weight.shape[-1], device="meta") \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py new file mode 100644 index 000000000..80d034f9a --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -0,0 +1,19 @@ +import torch +from ..registry import meta_patched_function + + +@meta_patched_function.register(torch.nn.functional.layer_norm) +def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + return torch.empty(input.shape, device='meta') + + +@meta_patched_function.register(torch.nn.functional.batch_norm) +def torch_nn_func_batchnorm(input, + running_mean, + running_var, + weight=None, + bias=None, + training=False, + momentum=0.1, + eps=1e-05): + return torch.empty(input.shape, device='meta') \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py new file mode 100644 index 000000000..ac1fe0c27 --- 
/dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -0,0 +1,24 @@ +import operator +import torch +from ..registry import meta_patched_function + + +@meta_patched_function.register(operator.getitem) +def operator_getitem(a, b): + # copied from huggingface.utils.fx + def to_concrete(t): + if isinstance(t, torch.Tensor): + concrete = torch.ones_like(t, device="cpu") + if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]: + concrete = concrete.to(torch.int64) + return concrete + return t + + if isinstance(a, torch.Tensor): + # TODO: infer shape without performing the computation. + if isinstance(b, tuple): + b = tuple(map(to_concrete, b)) + else: + b = to_concrete(b) + return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta") + return operator.getitem(a, b) diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py new file mode 100644 index 000000000..e3342a646 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -0,0 +1,108 @@ +import torch +from ..registry import meta_patched_function + + +@meta_patched_function.register(torch.arange) +def torch_arange(*args, **kwargs): + n = len(args) + step = 1 + if n == 1: + start = 0 + end = args[0] + elif n == 2: + start, end = args + else: + start, end, step = args + if isinstance(start, float): + start = int(start) + if isinstance(end, float): + start = int(end) + if isinstance(step, float): + step = int(step) + step = kwargs.get("step", step) + dtype = kwargs.get("dtype") + return torch.empty((end - start) // step, dtype=dtype, device="meta") + + +@meta_patched_function.register(torch.where) +def torch_where(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") + + +@meta_patched_function.register(torch.Tensor.repeat) +def torch_tensor_repeat(self, *sizes): + shape = list(self.shape) + for i, x in enumerate(sizes): + shape[i] *= x + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.index_select) +def torch_index_select(input, dim, index, *, out=None): + shape = list(input.shape) + shape[dim] = len(index) + return torch.empty(*shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.index_select) +def torch_tensor_index_select(self, dim, index): + return torch_index_select(self, dim, index) + + +@meta_patched_function.register(torch.squeeze) +def torch_squeeze(input, dim=None): + shape = list(input.shape) + if dim is not None: + if dim < 0: + dim = input.dim() + dim + if shape[dim] == 1: + shape.pop(dim) + else: + new_shape = [] + for dim_value in shape: + if dim_value == 1: + continue + new_shape.append(dim_value) + shape = new_shape + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.squeeze) +def torch_tensor_squeeze(self, dim=None): + return torch_squeeze(self, dim) + + +@meta_patched_function.register(torch.unsqueeze) +def torch_unsqueeze(input, dim): + shape = list(input.shape) + if dim < 0: + dim = input.dim() + 1 + dim + shape.insert(dim, 1) + return torch.empty(shape, device="meta") + + +@meta_patched_function.register(torch.Tensor.unsqueeze) +def torch_tensor_unsqueeze(self, dim): + return torch_unsqueeze(self, dim) + + +@meta_patched_function.register(torch.cat) +def torch_cat(tensors, dim=None, axis=None, *, out=None): + if 
dim is None and axis is None: + dim = 0 + if dim is None and axis is not None: + dim = axis + if dim < 0: + dim = tensors[0].dim() + dim + shapes = [t.shape for t in tensors] + shape = list(shapes[0]) + concatenated_dim = sum(shape[dim] for shape in shapes) + final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] + return torch.empty(final_shape, device="meta") + + +@meta_patched_function.register(torch.roll) +def torch_roll(input, shifts, dims=None): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py new file mode 100644 index 000000000..bd550487c --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -0,0 +1,6 @@ +from .activation_function import * +from .convolution import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py new file mode 100644 index 000000000..ed2f4bcaf --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -0,0 +1,11 @@ +import torch +from ..registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.ReLU) +@meta_patched_module.register(torch.nn.Sigmoid) +@meta_patched_module.register(torch.nn.GELU) +@meta_patched_module.register(torch.nn.Tanh) +@meta_patched_module.register(torch.nn.ReLU6) +def torch_nn_non_linear_act(self, input): + return torch.empty(input.shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py new file mode 100644 index 000000000..b600f4df2 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -0,0 +1,57 @@ +import math +import torch +from ..registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Conv1d) +def torch_nn_conv1d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d + l_in = input.shape[-1] + c_out = self.out_channels + l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + result_shape = input.shape[:-2] + ( + c_out, + l_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.Conv2d) +def torch_nn_conv2d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d + h_in, w_in = input.shape[-2:] + c_out = self.out_channels + h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + result_shape = input.shape[:-3] + ( + c_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') + + +@meta_patched_module.register(torch.nn.Conv3d) +def torch_nn_conv3d(self, input): + # the output shape is calculated using the formula stated + # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d + d_in, h_in, w_in = input.shape[-3:] + c_out = self.out_channels + d_out = math.floor((d_in + 2 
* self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * + (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + return torch.empty(result_shape, device='meta') diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py new file mode 100644 index 000000000..705d37735 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -0,0 +1,8 @@ +import torch +from ..registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Embedding) +def torch_nn_embedding(self, input): + result_shape = input.shape + (self.embedding_dim,) + return torch.empty(result_shape, device='meta') \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py new file mode 100644 index 000000000..1f22ffd60 --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -0,0 +1,9 @@ +import torch +from ..registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.Linear) +def torch_nn_linear(self, input): + last_dim = input.shape[-1] + assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' + return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py new file mode 100644 index 000000000..78a3620cc --- /dev/null +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -0,0 +1,20 @@ +import torch +from ..registry import meta_patched_module + + +@meta_patched_module.register(torch.nn.LayerNorm) +@meta_patched_module.register(torch.nn.GroupNorm) +@meta_patched_module.register(torch.nn.BatchNorm1d) +@meta_patched_module.register(torch.nn.BatchNorm2d) +@meta_patched_module.register(torch.nn.BatchNorm3d) +def torch_nn_normalize(self, input): + # check shape + if isinstance(self, torch.nn.BatchNorm1d): + assert input.dim() in [2, 3] + elif isinstance(self, torch.nn.BatchNorm2d): + assert input.dim() == 4 + elif isinstance(self, torch.nn.BatchNorm3d): + assert input.dim() == 5 + + # normalization maintain the same shape as the input + return input.clone() \ No newline at end of file diff --git a/colossalai/fx/tracer/meta_patch/patched_module.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py similarity index 63% rename from colossalai/fx/tracer/meta_patch/patched_module.py rename to colossalai/fx/tracer/meta_patch/patched_module/pooling.py index 787e7e68b..a336120f5 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -1,91 +1,6 @@ import math import torch -from .registry import meta_patched_module - - -@meta_patched_module.register(torch.nn.Linear) -def torch_nn_linear(self, input): - last_dim = input.shape[-1] - assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' - return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") - - 
-@meta_patched_module.register(torch.nn.LayerNorm) -@meta_patched_module.register(torch.nn.GroupNorm) -@meta_patched_module.register(torch.nn.BatchNorm1d) -@meta_patched_module.register(torch.nn.BatchNorm2d) -@meta_patched_module.register(torch.nn.BatchNorm3d) -def torch_nn_normalize(self, input): - # check shape - if isinstance(self, torch.nn.BatchNorm1d): - assert input.dim() in [2, 3] - elif isinstance(self, torch.nn.BatchNorm2d): - assert input.dim() == 4 - elif isinstance(self, torch.nn.BatchNorm3d): - assert input.dim() == 5 - - # normalization maintain the same shape as the input - return input.clone() - - -@meta_patched_module.register(torch.nn.Embedding) -def torch_nn_embedding(self, input): - result_shape = input.shape + (self.embedding_dim,) - return torch.empty(result_shape, device='meta') - - -@meta_patched_module.register(torch.nn.Conv1d) -def torch_nn_conv1d(self, input): - # the output shape is calculated using the formula stated - # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d - l_in = input.shape[-1] - c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - result_shape = input.shape[:-2] + ( - c_out, - l_out, - ) - return torch.empty(result_shape, device='meta') - - -@meta_patched_module.register(torch.nn.Conv2d) -def torch_nn_conv2d(self, input): - # the output shape is calculated using the formula stated - # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d - h_in, w_in = input.shape[-2:] - c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - result_shape = input.shape[:-3] + ( - c_out, - h_out, - w_out, - ) - return torch.empty(result_shape, device='meta') - - -@meta_patched_module.register(torch.nn.Conv3d) -def torch_nn_conv3d(self, input): - # the output shape is calculated using the formula stated - # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d - d_in, h_in, w_in = input.shape[-3:] - c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) - result_shape = input.shape[:-4] + ( - c_out, - d_out, - h_out, - w_out, - ) - return torch.empty(result_shape, device='meta') +from ..registry import meta_patched_module @meta_patched_module.register(torch.nn.AvgPool1d) @@ -274,10 +189,4 @@ def torch_nn_adapative_pooling_3d(self, input): self.output_size, self.output_size, ) - return torch.empty(result_shape, device='meta') - - -@meta_patched_module.register(torch.nn.ReLU) -@meta_patched_module.register(torch.nn.ReLU6) -def torch_nn_func_relu(self, input): - return torch.empty(input.shape, device='meta') + return torch.empty(result_shape, device='meta') \ No newline at end of file
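
Every file touched by this refactor follows the same pattern: a registry decorator maps a `torch` callable (or `nn.Module` type) to a shape-only replacement that runs on meta tensors, so the tracer can infer output shapes without doing any real computation. The sketch below illustrates that pattern in a self-contained way. It is a minimal illustration only: the stand-in `register`/`_patched_functions` names are hypothetical, since the real registry lives in `colossalai/fx/tracer/meta_patch/registry.py` (imported as `meta_patched_function` above) and its exact API is not shown in this patch. The `torch.bmm` shape rule is copied from `patched_function/arithmetic.py`.

```python
# Minimal sketch of the registry + meta-patch pattern used by these files.
# NOTE: `register` and `_patched_functions` are hypothetical stand-ins for
# ColossalAI's `meta_patched_function` registry, whose API is not in this diff.
import torch

_patched_functions = {}


def register(target):
    """Map a torch callable to its shape-only (meta tensor) replacement."""
    def wrapper(fn):
        _patched_functions[target] = fn
        return fn
    return wrapper


@register(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
    # Same rule as patched_function/arithmetic.py:
    # (B, n, m) x (B, m, p) -> (B, n, p), materialized on the meta device.
    if out is not None:
        raise ValueError("out= is not supported for meta-tensor analysis")
    batch_size, n, m = input.shape
    _, _, p = mat2.shape
    return torch.empty(batch_size, n, p, device="meta")


if __name__ == "__main__":
    # Shape inference without touching real data: inputs live on the 'meta' device.
    a = torch.empty(4, 8, 16, device="meta")
    b = torch.empty(4, 16, 32, device="meta")
    out = _patched_functions[torch.bmm](a, b)
    print(out.shape, out.device)  # torch.Size([4, 8, 32]) meta
```

In the real tracer the lookup step is done for you: when a call to a registered target is traced, the patched function is invoked in its place, which is why each of the new `patched_function/*.py` and `patched_module/*.py` files only needs to declare shape rules and register them.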