mirror of https://github.com/hpcaitech/ColossalAI
[fx] hack __torch_dispatch__ for meta tensor and autograd. (#1515)
* [fx] hack __torch_dispatch__ for meta tensor and autograd. * [fx] hack __torch_dispatch__ for meta tensor and autograd. * [fx] hack __torch_dispatch__ for meta tensor and autograd. * [fx] hack __torch_dispatch__ for meta tensor and autograd. * [fx] hack __torch_dispatch__ for meta tensor and autograd. * [fx] add bad case detections. * [fx] add bad case detections. * [fx] rename MetaTensor attributes. * [fx] fix unexpected error. * [fx] fix unexpected error. * [fx] fix unexpected error. * [fx] fix unexpected error. * [fx] fix unexpected error. * [fx] add register backward for native_batch_norm_backward. * [fx] add more meta backend support for nn.Modules. * [fx] add meta backend to support timm and torchvision models. * [fx] add meta hardswish for timm models.pull/1530/head
parent
4537d39df9
commit
5cc849f6ce
|
@ -1,12 +1,13 @@
|
|||
from operator import add, getitem
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, map_aggregate, Argument, Target
|
||||
from torch.fx.node import Node, Argument, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
||||
from functools import reduce
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method
|
||||
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
@ -75,9 +76,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
"""
|
||||
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
|
||||
"""
|
||||
for elem in args:
|
||||
if isinstance(elem, torch.Tensor):
|
||||
assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
|
||||
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
|
||||
return super().run(*args, initial_env, enable_io_processing)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
@ -103,7 +102,7 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
else:
|
||||
return TensorMetadata(None, None, False, None, 0, False)
|
||||
|
||||
meta = map_aggregate(result, extract_tensor_meta)
|
||||
meta = tree_map(extract_tensor_meta, result)
|
||||
n.meta['tensor_meta'] = meta
|
||||
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
from .registry import *
|
||||
try:
|
||||
from ._meta_registrations import *
|
||||
except:
|
||||
import torch
|
||||
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
|
||||
from .meta_tensor import MetaTensor
|
||||
from .registry import meta_profiler_function, meta_profiler_module
|
||||
from .profiler_function import *
|
||||
from .profiler_module import *
|
||||
from .profiler import *
|
||||
|
|
|
@ -0,0 +1,339 @@
|
|||
# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
|
||||
# should be activated for PyTorch version 1.12.0 and below
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
|
||||
|
||||
meta_table = {}
|
||||
|
||||
|
||||
def register_meta(op, register_dispatcher=True):
|
||||
def wrapper(f):
|
||||
def add_func(op):
|
||||
meta_table[op] = f
|
||||
if register_dispatcher:
|
||||
name = (
|
||||
op.__name__
|
||||
if op._overloadname != "default"
|
||||
else op.overloadpacket.__name__
|
||||
)
|
||||
meta_lib.impl(name, f)
|
||||
|
||||
tree_map(add_func, op)
|
||||
return f
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/pull/79834
|
||||
@register_meta(aten.convolution.default)
|
||||
def meta_conv(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
is_transposed: bool,
|
||||
output_padding: List[int],
|
||||
groups: int,
|
||||
):
|
||||
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
|
||||
"""
|
||||
Formula to apply to calculate the length of some dimension of the output
|
||||
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||
Args:
|
||||
ln: length of the dimension
|
||||
p: padding in that dim
|
||||
d: dilation in that dim
|
||||
k: kernel size in that dim
|
||||
s: stride in that dim
|
||||
Returns:
|
||||
The output length
|
||||
"""
|
||||
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
|
||||
|
||||
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
|
||||
"""
|
||||
Formula to apply to calculate the length of some dimension of the output
|
||||
if transposed convolution is used.
|
||||
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
|
||||
Args:
|
||||
ln: length of the dimension
|
||||
p: padding in that dim
|
||||
d: dilation in that dim
|
||||
k: kernel size in that dim
|
||||
s: stride in that dim
|
||||
op: output padding in that dim
|
||||
Returns:
|
||||
The output length
|
||||
"""
|
||||
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
|
||||
|
||||
def calc_conv_nd_return_shape(
|
||||
dims: torch.Size,
|
||||
kernel_size: torch.Size,
|
||||
stride: Union[List[int], int],
|
||||
padding: Union[List[int], int],
|
||||
dilation: Union[List[int], int],
|
||||
output_padding: Optional[Union[List[int], int]] = None,
|
||||
):
|
||||
ret_shape = []
|
||||
if isinstance(stride, int):
|
||||
stride = [stride] * len(dims)
|
||||
elif len(stride) == 1:
|
||||
stride = [stride[0]] * len(dims)
|
||||
|
||||
if isinstance(padding, int):
|
||||
padding = [padding] * len(dims)
|
||||
elif len(padding) == 1:
|
||||
padding = [padding[0]] * len(dims)
|
||||
|
||||
if isinstance(dilation, int):
|
||||
dilation = [dilation] * len(dims)
|
||||
elif len(dilation) == 1:
|
||||
dilation = [dilation[0]] * len(dims)
|
||||
|
||||
output_padding_list: Optional[List[int]] = None
|
||||
if output_padding:
|
||||
if isinstance(output_padding, int):
|
||||
output_padding_list = [output_padding] * len(dims)
|
||||
elif len(output_padding) == 1:
|
||||
output_padding_list = [output_padding[0]] * len(dims)
|
||||
else:
|
||||
output_padding_list = output_padding
|
||||
|
||||
for i in range(len(dims)):
|
||||
# If output_padding is present, we are dealing with a transposed convolution
|
||||
if output_padding_list:
|
||||
ret_shape.append(
|
||||
_formula_transposed(
|
||||
dims[i],
|
||||
padding[i],
|
||||
dilation[i],
|
||||
kernel_size[i],
|
||||
stride[i],
|
||||
output_padding_list[i],
|
||||
)
|
||||
)
|
||||
else:
|
||||
ret_shape.append(
|
||||
_formula(
|
||||
dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
|
||||
)
|
||||
)
|
||||
return ret_shape
|
||||
|
||||
def pick_memory_format():
|
||||
if input_tensor.is_contiguous(memory_format=torch.channels_last):
|
||||
return torch.channels_last
|
||||
elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
|
||||
return torch.contiguous_format
|
||||
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
|
||||
return torch.preserve_format
|
||||
|
||||
kernel_size = weight.shape[2:]
|
||||
dims = input_tensor.shape[2:]
|
||||
if is_transposed:
|
||||
out_channels = groups * weight.shape[1]
|
||||
|
||||
shape_out = calc_conv_nd_return_shape(
|
||||
dims,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
output_padding,
|
||||
)
|
||||
|
||||
else:
|
||||
out_channels = weight.shape[0]
|
||||
if weight.shape[1] != input_tensor.shape[1] / groups:
|
||||
raise RuntimeError("Invalid channel dimensions")
|
||||
shape_out = calc_conv_nd_return_shape(
|
||||
dims, kernel_size, stride, padding, dilation
|
||||
)
|
||||
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
|
||||
mem_fmt = pick_memory_format()
|
||||
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
|
||||
return out
|
||||
|
||||
|
||||
@register_meta(aten.convolution_backward.default)
|
||||
def meta_conv_backward(
|
||||
grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
||||
bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
|
||||
):
|
||||
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
|
||||
|
||||
|
||||
@register_meta(aten.relu.default)
|
||||
def meta_relu(input: torch.Tensor):
|
||||
return torch.empty_like(input)
|
||||
|
||||
|
||||
@register_meta(aten.hardswish.default)
|
||||
def meta_hardswish(input: torch.Tensor):
|
||||
return torch.empty_like(input)
|
||||
|
||||
|
||||
@register_meta(aten.hardswish_backward.default)
|
||||
def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
|
||||
grad_in = torch.empty_like(input)
|
||||
return grad_in
|
||||
|
||||
|
||||
@register_meta([aten.roll.default, ])
|
||||
def meta_roll(input:torch.Tensor, shifts, dims):
|
||||
return torch.empty_like(input)
|
||||
|
||||
|
||||
@register_meta(aten.native_batch_norm.default)
|
||||
def meta_bn(
|
||||
input: torch.Tensor,
|
||||
weight, bias, running_mean, running_var, training, momentum, eps
|
||||
):
|
||||
n_input = input.size(1)
|
||||
|
||||
output = torch.empty_like(input)
|
||||
running_mean = torch.empty((n_input), device='meta')
|
||||
running_var = torch.empty((n_input), device='meta')
|
||||
return output, running_mean, running_var
|
||||
|
||||
|
||||
@register_meta(aten.native_batch_norm_backward.default)
|
||||
def meta_bn_backward(
|
||||
dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
||||
running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
|
||||
):
|
||||
dX = torch.empty_like(input)
|
||||
dgamma = torch.empty_like(weight)
|
||||
dbeta = torch.empty_like(weight)
|
||||
return dX, dgamma, dbeta
|
||||
|
||||
|
||||
@register_meta(aten.native_layer_norm.default)
|
||||
def meta_ln(
|
||||
input: torch.Tensor,
|
||||
normalized_shape, weight, bias, eps
|
||||
):
|
||||
n_input = input.size(1)
|
||||
|
||||
output = torch.empty_like(input)
|
||||
running_mean = torch.empty((n_input), device='meta')
|
||||
running_var = torch.empty((n_input), device='meta')
|
||||
return output, running_mean, running_var
|
||||
|
||||
|
||||
@register_meta(aten.native_layer_norm_backward.default)
|
||||
def meta_ln_backward(
|
||||
dY: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
normalized_shape, mean, rstd, weight, bias, grad_input_mask
|
||||
):
|
||||
dX = torch.empty_like(input)
|
||||
dgamma = torch.empty_like(weight)
|
||||
dbeta = torch.empty_like(bias)
|
||||
return dX, dgamma, dbeta
|
||||
|
||||
|
||||
@register_meta(aten._adaptive_avg_pool2d_backward.default)
|
||||
def meta_adaptive_avg_pool2d_backward(
|
||||
grad_output: torch.Tensor, input: torch.Tensor,
|
||||
):
|
||||
grad_input = torch.empty_like(input)
|
||||
return torch.empty_like(input)
|
||||
|
||||
|
||||
@register_meta(aten.index.Tensor)
|
||||
def meta_index_Tensor(self, indices):
|
||||
assert indices, "at least one index must be provided"
|
||||
# aten::index is the internal advanced indexing implementation
|
||||
# checkIndexTensorTypes and expandTensors
|
||||
result: List[Optional[torch.Tensor]] = []
|
||||
for i, index in enumerate(indices):
|
||||
if index is not None:
|
||||
assert index.dtype in [torch.long, torch.int8, torch.bool],\
|
||||
"tensors used as indices must be long, byte or bool tensors"
|
||||
if index.dtype in [torch.int8, torch.bool]:
|
||||
nonzero = index.nonzero()
|
||||
k = len(result)
|
||||
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
|
||||
for j in range(index.ndim):
|
||||
assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
|
||||
result.append(nonzero.select(1, j))
|
||||
else:
|
||||
result.append(index)
|
||||
else:
|
||||
result.append(index)
|
||||
indices = result
|
||||
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||
# expand_outplace
|
||||
import torch._refs as refs # avoid import cycle in mypy
|
||||
|
||||
indices = list(refs._maybe_broadcast(*indices))
|
||||
# add missing null tensors
|
||||
while len(indices) < self.ndim:
|
||||
indices.append(None)
|
||||
|
||||
# hasContiguousSubspace
|
||||
# true if all non-null tensors are adjacent
|
||||
# See:
|
||||
# https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
|
||||
# https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
|
||||
state = 0
|
||||
has_contiguous_subspace = False
|
||||
for index in indices:
|
||||
if state == 0:
|
||||
if index is not None:
|
||||
state = 1
|
||||
elif state == 1:
|
||||
if index is None:
|
||||
state = 2
|
||||
else:
|
||||
if index is not None:
|
||||
break
|
||||
else:
|
||||
has_contiguous_subspace = True
|
||||
|
||||
# transposeToFront
|
||||
# This is the logic that causes the newly inserted dimensions to show up
|
||||
# at the beginning of the tensor, if they're not contiguous
|
||||
if not has_contiguous_subspace:
|
||||
dims = []
|
||||
transposed_indices = []
|
||||
for i, index in enumerate(indices):
|
||||
if index is not None:
|
||||
dims.append(i)
|
||||
transposed_indices.append(index)
|
||||
for i, index in enumerate(indices):
|
||||
if index is None:
|
||||
dims.append(i)
|
||||
transposed_indices.append(index)
|
||||
self = self.permute(dims)
|
||||
indices = transposed_indices
|
||||
|
||||
# AdvancedIndex::AdvancedIndex
|
||||
# Now we can assume the indices have contiguous subspace
|
||||
# This is simplified from AdvancedIndex which goes to more effort
|
||||
# to put the input and indices in a form so that TensorIterator can
|
||||
# take them. If we write a ref for this, probably that logic should
|
||||
# get implemented
|
||||
before_shape: List[int] = []
|
||||
after_shape: List[int] = []
|
||||
replacement_shape: List[int] = []
|
||||
for dim, index in enumerate(indices):
|
||||
if index is None:
|
||||
if replacement_shape:
|
||||
after_shape.append(self.shape[dim])
|
||||
else:
|
||||
before_shape.append(self.shape[dim])
|
||||
else:
|
||||
replacement_shape = list(index.shape)
|
||||
return self.new_empty(before_shape + replacement_shape + after_shape)
|
|
@ -0,0 +1,50 @@
|
|||
import torch
|
||||
from torch.utils._pytree import tree_map, tree_flatten
|
||||
|
||||
|
||||
__all__ = ['MetaTensor']
|
||||
|
||||
|
||||
class MetaTensor(torch.Tensor):
|
||||
"""
|
||||
A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
|
||||
"""
|
||||
|
||||
_tensor: torch.Tensor
|
||||
|
||||
__slots__ = ['_tensor']
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
# The wrapping tensor (MetaTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls, elem.size(),
|
||||
strides=elem.stride(), storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype, layout=elem.layout,
|
||||
device='cpu', requires_grad=elem.requires_grad
|
||||
) # deceive the frontend for aten selections
|
||||
r._tensor = elem
|
||||
# ...the real tensor is held as an element on the tensor.
|
||||
return r
|
||||
|
||||
@ classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
def unwrap(x):
|
||||
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
|
||||
x = MetaTensor(x)
|
||||
return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
|
||||
|
||||
args = tree_map(unwrap, args)
|
||||
kwargs = tree_map(unwrap, kwargs)
|
||||
|
||||
# run aten for backend=CPU but actually on backend=Meta
|
||||
out = func(*args, **kwargs)
|
||||
|
||||
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
|
||||
# our custom tensor subclass
|
||||
def wrap(x):
|
||||
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
return tree_map(wrap, out)
|
|
@ -1,10 +1,8 @@
|
|||
from functools import partial
|
||||
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
|
||||
from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
|
||||
import torch
|
||||
from torch.fx.node import Argument, Target, map_aggregate
|
||||
from torch.fx.node import Argument, Target
|
||||
from torch.fx._compatibility import compatibility
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
|
||||
from . import meta_profiler_function, meta_profiler_module
|
||||
|
||||
__all__ = [
|
||||
|
@ -58,6 +56,10 @@ INPLACE_METHOD = [
|
|||
'reshape',
|
||||
'dim',
|
||||
'flatten',
|
||||
'size',
|
||||
'view',
|
||||
'unsqueeze',
|
||||
'to',
|
||||
]
|
||||
|
||||
# TODO: list all call_methods that are not inplace here
|
||||
|
@ -137,8 +139,6 @@ def profile_function(target: 'Target') -> Callable:
|
|||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||
assert meta_profiler_function.has(target) or meta_profiler_function.has(
|
||||
target.__name__), CALL_FUNCTION_MSG.format(target)
|
||||
# ensure all arguments satisfy `device='meta'`
|
||||
args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
|
||||
|
||||
# call_function has no parameters
|
||||
param_size = 0
|
||||
|
@ -154,13 +154,7 @@ def profile_function(target: 'Target') -> Callable:
|
|||
return result, MetaProfile(param_size, activation_size, flops, macs)
|
||||
|
||||
f.__name__ = target.__name__
|
||||
# fetch patched function
|
||||
if meta_patched_function.has(target):
|
||||
func = meta_patched_function.get(target)
|
||||
elif meta_patched_function.has(target.__name__):
|
||||
func = meta_patched_function.get(target.__name__)
|
||||
else:
|
||||
func = target
|
||||
func = target
|
||||
return f
|
||||
|
||||
|
||||
|
@ -180,8 +174,6 @@ def profile_method(target: 'Target') -> Callable:
|
|||
# execute the method and return the result
|
||||
assert isinstance(target, str), f'{target} instance is not str.'
|
||||
|
||||
# ensure all arguments satisfy `device='meta'`
|
||||
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
|
||||
result = getattr(self_obj, target)(*args_tail, **kwargs)
|
||||
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
|
||||
target, INPLACE_METHOD, NON_INPLACE_METHOD)
|
||||
|
@ -216,8 +208,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
|||
|
||||
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
|
||||
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
|
||||
# ensure all arguments satisfy `device='meta'`
|
||||
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
|
||||
|
||||
# only `nn.Module` has parameters
|
||||
param_size = calculate_param_size(module)
|
||||
activation_size = 0
|
||||
result = func(*args, **kwargs)
|
||||
|
@ -228,9 +220,5 @@ def profile_module(module: torch.nn.Module) -> Callable:
|
|||
return result, MetaProfile(param_size, activation_size, flops, macs)
|
||||
|
||||
f.__name__ = module.__class__.__name__
|
||||
# fetch patched module
|
||||
if meta_patched_module.has(type(module)):
|
||||
func = partial(meta_patched_module.get(type(module)), module)
|
||||
else:
|
||||
func = module.forward
|
||||
func = module.forward
|
||||
return f
|
||||
|
|
Loading…
Reference in New Issue