[fx] hack __torch_dispatch__ for meta tensor and autograd. (#1515)

* [fx] hack __torch_dispatch__ for meta tensor and autograd.

* [fx] add bad case detections.

* [fx] rename MetaTensor attributes.

* [fx] fix unexpected error.

* [fx] add register backward for native_batch_norm_backward.

* [fx] add more meta backend support for nn.Modules.

* [fx] add meta backend to support timm and torchvision models.

* [fx] add meta hardswish for timm models.
Super Daniel 2022-08-31 16:30:16 +08:00 committed by GitHub
parent 4537d39df9
commit 5cc849f6ce
5 changed files with 410 additions and 28 deletions


@@ -1,12 +1,13 @@
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, map_aggregate, Argument, Target
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
@compatibility(is_backward_compatible=True)
@@ -75,9 +76,7 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
for elem in args:
if isinstance(elem, torch.Tensor):
assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
return super().run(*args, initial_env, enable_io_processing)
@compatibility(is_backward_compatible=True)
@@ -103,7 +102,7 @@ class MetaInfoProp(torch.fx.Interpreter):
else:
return TensorMetadata(None, None, False, None, 0, False)
meta = map_aggregate(result, extract_tensor_meta)
meta = tree_map(extract_tensor_meta, result)
n.meta['tensor_meta'] = meta
# TODO: the attribute node_size should be removed in the future
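
With the `run` override above, callers no longer have to move inputs to `device='meta'` themselves. A hedged usage sketch follows (the module, shapes, and the `colossalai.fx.passes.meta_info_prop` import path are illustrative assumptions, not part of the diff):

# Hypothetical usage sketch -- module, shapes and import path are assumptions.
import torch
import torch.fx
from colossalai.fx.passes.meta_info_prop import MetaInfoProp    # assumed location of the class above

class MLP(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x).relu()

gm = torch.fx.symbolic_trace(MLP())
# Inputs may live on any device; run() now wraps them in MetaTensor internally.
MetaInfoProp(gm).run(torch.rand(4, 8))
print({n.name: n.meta.get('tensor_meta') for n in gm.graph.nodes})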


@@ -1,4 +1,10 @@
from .registry import *
try:
from ._meta_registrations import *
except:
import torch
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .meta_tensor import MetaTensor
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import *


@@ -0,0 +1,339 @@
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# should be activated for PyTorch version 1.12.0 and below
from typing import List, Optional, Tuple, Union
import torch
from torch.utils._pytree import tree_map
aten = torch.ops.aten
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
meta_table = {}
def register_meta(op, register_dispatcher=True):

    def wrapper(f):

        def add_func(op):
            meta_table[op] = f
            if register_dispatcher:
                name = (
                    op.__name__
                    if op._overloadname != "default"
                    else op.overloadpacket.__name__
                )
                meta_lib.impl(name, f)

        tree_map(add_func, op)
        return f

    return wrapper
# https://github.com/pytorch/pytorch/pull/79834
@register_meta(aten.convolution.default)
def meta_conv(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
Args:
ln: length of the dimension
p: padding in that dim
d: dilation in that dim
k: kernel size in that dim
s: stride in that dim
Returns:
The output length
"""
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
"""
Formula to apply to calculate the length of some dimension of the output
if transposed convolution is used.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
Args:
ln: length of the dimension
p: padding in that dim
d: dilation in that dim
k: kernel size in that dim
s: stride in that dim
op: output padding in that dim
Returns:
The output length
"""
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
def calc_conv_nd_return_shape(
dims: torch.Size,
kernel_size: torch.Size,
stride: Union[List[int], int],
padding: Union[List[int], int],
dilation: Union[List[int], int],
output_padding: Optional[Union[List[int], int]] = None,
):
ret_shape = []
if isinstance(stride, int):
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)
if isinstance(padding, int):
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)
if isinstance(dilation, int):
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)
output_padding_list: Optional[List[int]] = None
if output_padding:
if isinstance(output_padding, int):
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)
else:
output_padding_list = output_padding
for i in range(len(dims)):
# If output_padding is present, we are dealing with a transposed convolution
if output_padding_list:
ret_shape.append(
_formula_transposed(
dims[i],
padding[i],
dilation[i],
kernel_size[i],
stride[i],
output_padding_list[i],
)
)
else:
ret_shape.append(
_formula(
dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
)
)
return ret_shape
def pick_memory_format():
if input_tensor.is_contiguous(memory_format=torch.channels_last):
return torch.channels_last
elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
return torch.contiguous_format
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
return torch.preserve_format
kernel_size = weight.shape[2:]
dims = input_tensor.shape[2:]
if is_transposed:
out_channels = groups * weight.shape[1]
shape_out = calc_conv_nd_return_shape(
dims,
kernel_size,
stride,
padding,
dilation,
output_padding,
)
else:
out_channels = weight.shape[0]
if weight.shape[1] != input_tensor.shape[1] / groups:
raise RuntimeError("Invalid channel dimensions")
shape_out = calc_conv_nd_return_shape(
dims, kernel_size, stride, padding, dilation
)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(
grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
):
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor):
return torch.empty_like(input)
@register_meta(aten.hardswish.default)
def meta_hardswish(input: torch.Tensor):
return torch.empty_like(input)
@register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
grad_in = torch.empty_like(input)
return grad_in
@register_meta([aten.roll.default, ])
def meta_roll(input:torch.Tensor, shifts, dims):
return torch.empty_like(input)
@register_meta(aten.native_batch_norm.default)
def meta_bn(
input: torch.Tensor,
weight, bias, running_mean, running_var, training, momentum, eps
):
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
return output, running_mean, running_var
@register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward(
dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight)
return dX, dgamma, dbeta
@register_meta(aten.native_layer_norm.default)
def meta_ln(
input: torch.Tensor,
normalized_shape, weight, bias, eps
):
n_input = input.size(1)
output = torch.empty_like(input)
running_mean = torch.empty((n_input), device='meta')
running_var = torch.empty((n_input), device='meta')
return output, running_mean, running_var
@register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward(
dY: torch.Tensor,
input: torch.Tensor,
normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
dX = torch.empty_like(input)
dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias)
return dX, dgamma, dbeta
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor, input: torch.Tensor,
):
grad_input = torch.empty_like(input)
return torch.empty_like(input)
@register_meta(aten.index.Tensor)
def meta_index_Tensor(self, indices):
assert indices, "at least one index must be provided"
# aten::index is the internal advanced indexing implementation
# checkIndexTensorTypes and expandTensors
result: List[Optional[torch.Tensor]] = []
for i, index in enumerate(indices):
if index is not None:
assert index.dtype in [torch.long, torch.int8, torch.bool],\
"tensors used as indices must be long, byte or bool tensors"
if index.dtype in [torch.int8, torch.bool]:
nonzero = index.nonzero()
k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim):
assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j))
else:
result.append(index)
else:
result.append(index)
indices = result
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
# expand_outplace
import torch._refs as refs # avoid import cycle in mypy
indices = list(refs._maybe_broadcast(*indices))
# add missing null tensors
while len(indices) < self.ndim:
indices.append(None)
# hasContiguousSubspace
# true if all non-null tensors are adjacent
# See:
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
state = 0
has_contiguous_subspace = False
for index in indices:
if state == 0:
if index is not None:
state = 1
elif state == 1:
if index is None:
state = 2
else:
if index is not None:
break
else:
has_contiguous_subspace = True
# transposeToFront
# This is the logic that causes the newly inserted dimensions to show up
# at the beginning of the tensor, if they're not contiguous
if not has_contiguous_subspace:
dims = []
transposed_indices = []
for i, index in enumerate(indices):
if index is not None:
dims.append(i)
transposed_indices.append(index)
for i, index in enumerate(indices):
if index is None:
dims.append(i)
transposed_indices.append(index)
self = self.permute(dims)
indices = transposed_indices
# AdvancedIndex::AdvancedIndex
# Now we can assume the indices have contiguous subspace
# This is simplified from AdvancedIndex which goes to more effort
# to put the input and indices in a form so that TensorIterator can
# take them. If we write a ref for this, probably that logic should
# get implemented
before_shape: List[int] = []
after_shape: List[int] = []
replacement_shape: List[int] = []
for dim, index in enumerate(indices):
if index is None:
if replacement_shape:
after_shape.append(self.shape[dim])
else:
before_shape.append(self.shape[dim])
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)
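
A quick sanity check of the registrations above (illustrative only, not part of the commit; shapes are arbitrary, and the package import is assumed to trigger the patching shown in `__init__.py`):

# Illustrative shape check. Assumes PyTorch <= 1.12, where aten.convolution has
# no meta kernel until the registration above is applied.
import torch
import colossalai.fx.profiler    # importing the package applies the meta patches

x = torch.empty(2, 3, 224, 224, device='meta')
w = torch.empty(64, 3, 7, 7, device='meta')
# aten.convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
out = torch.ops.aten.convolution(x, w, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1)
print(out.shape)    # expected: torch.Size([2, 64, 112, 112]) per the conv output-shape formula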


@@ -0,0 +1,50 @@
import torch
from torch.utils._pytree import tree_map, tree_flatten
__all__ = ['MetaTensor']
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
    """

    _tensor: torch.Tensor

    __slots__ = ['_tensor']

    @staticmethod
    def __new__(cls, elem):
        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            dtype=elem.dtype, layout=elem.layout,
            device='cpu', requires_grad=elem.requires_grad
        )    # deceive the frontend for aten selections
        r._tensor = elem
        # ...the real tensor is held as an element on the tensor.
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

        def unwrap(x):
            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                x = MetaTensor(x)
            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        # run aten for backend=CPU but actually on backend=Meta
        out = func(*args, **kwargs)

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return MetaTensor(x) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)
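
The wrapper above is what lets autograd run against the meta backend. A minimal sketch of how it might be exercised (tensor shapes are illustrative, and the `.grad` behaviour is the intended outcome rather than a tested guarantee):

# Minimal sketch, not part of the diff: every aten call made by forward *and*
# backward is unwrapped to device='meta', so no real memory is ever allocated.
import torch
from colossalai.fx.profiler import MetaTensor    # exposed via the __init__.py change above

x = MetaTensor(torch.empty(4, 8, device='meta', requires_grad=True))
w = MetaTensor(torch.empty(8, 8, device='meta', requires_grad=True))
loss = (x @ w).sum()
loss.backward()            # autograd runs, dispatching each op to the meta backend
print(x.grad.shape)        # expected: torch.Size([4, 8]); gradients carry shapes, not data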


@@ -1,10 +1,8 @@
from functools import partial
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
import torch
from torch.fx.node import Argument, Target, map_aggregate
from torch.fx.node import Argument, Target
from torch.fx._compatibility import compatibility
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
from . import meta_profiler_function, meta_profiler_module
__all__ = [
@@ -58,6 +56,10 @@ INPLACE_METHOD = [
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
]
# TODO: list all call_methods that are not inplace here
@@ -137,8 +139,6 @@ def profile_function(target: 'Target') -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
# ensure all arguments satisfy `device='meta'`
args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
# call_function has no parameters
param_size = 0
@@ -154,13 +154,7 @@
return result, MetaProfile(param_size, activation_size, flops, macs)
f.__name__ = target.__name__
# fetch patched function
if meta_patched_function.has(target):
func = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
func = meta_patched_function.get(target.__name__)
else:
func = target
func = target
return f
@@ -180,8 +174,6 @@ def profile_method(target: 'Target') -> Callable:
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
# ensure all arguments satisfy `device='meta'`
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
@@ -216,8 +208,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
# ensure all arguments satisfy `device='meta'`
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
# only `nn.Module` has parameters
param_size = calculate_param_size(module)
activation_size = 0
result = func(*args, **kwargs)
@@ -228,9 +220,5 @@
return result, MetaProfile(param_size, activation_size, flops, macs)
f.__name__ = module.__class__.__name__
# fetch patched module
if meta_patched_module.has(type(module)):
func = partial(meta_patched_module.get(type(module)), module)
else:
func = module.forward
func = module.forward
return f