[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
YuliangLiu0306 2023-03-22 10:40:33 +08:00 committed by GitHub
parent b429529365
commit f57d34958b
28 changed files with 1014 additions and 863 deletions

View File

@@ -6,11 +6,15 @@
from typing import Callable, List, Optional, Tuple, Union
import torch
from packaging import version
from torch.utils._pytree import tree_map
aten = torch.ops.aten
try:
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
except AttributeError:
    meta_lib = None
meta_table = {}
@@ -50,432 +54,411 @@ def register_meta(op, register_dispatcher=True):
return wrapper
if version.parse(torch.__version__) >= version.parse('1.12.0'):

    # ============================== Convolutions ======================================
    # https://github.com/pytorch/pytorch/pull/79834
    @register_meta(aten.convolution.default)
    def meta_conv(
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        is_transposed: bool,
        output_padding: List[int],
        groups: int,
    ):

        def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
            """
            Formula to apply to calculate the length of some dimension of the output
            See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

            Args:
                ln: length of the dimension
                p: padding in that dim
                d: dilation in that dim
                k: kernel size in that dim
                s: stride in that dim
            Returns:
                The output length
            """
            return (ln + 2 * p - d * (k - 1) - 1) // s + 1

        def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
            """
            Formula to apply to calculate the length of some dimension of the output
            if transposed convolution is used.
            See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

            Args:
                ln: length of the dimension
                p: padding in that dim
                d: dilation in that dim
                k: kernel size in that dim
                s: stride in that dim
                op: output padding in that dim
            Returns:
                The output length
            """
            return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

        def calc_conv_nd_return_shape(
            dims: torch.Size,
            kernel_size: torch.Size,
            stride: Union[List[int], int],
            padding: Union[List[int], int],
            dilation: Union[List[int], int],
            output_padding: Optional[Union[List[int], int]] = None,
        ):
            ret_shape = []
            if isinstance(stride, int):
                stride = [stride] * len(dims)
            elif len(stride) == 1:
                stride = [stride[0]] * len(dims)

            if isinstance(padding, int):
                padding = [padding] * len(dims)
            elif len(padding) == 1:
                padding = [padding[0]] * len(dims)

            if isinstance(dilation, int):
                dilation = [dilation] * len(dims)
            elif len(dilation) == 1:
                dilation = [dilation[0]] * len(dims)

            output_padding_list: Optional[List[int]] = None
            if output_padding:
                if isinstance(output_padding, int):
                    output_padding_list = [output_padding] * len(dims)
                elif len(output_padding) == 1:
                    output_padding_list = [output_padding[0]] * len(dims)
                else:
                    output_padding_list = output_padding

            for i in range(len(dims)):
                # If output_padding is present, we are dealing with a transposed convolution
                if output_padding_list:
                    ret_shape.append(
                        _formula_transposed(
                            dims[i],
                            padding[i],
                            dilation[i],
                            kernel_size[i],
                            stride[i],
                            output_padding_list[i],
                        ))
                else:
                    ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
            return ret_shape

        def pick_memory_format():
            if input_tensor.is_contiguous(memory_format=torch.channels_last):
                return torch.channels_last
            elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
                return torch.contiguous_format
            elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
                return torch.preserve_format

        kernel_size = weight.shape[2:]
        dims = input_tensor.shape[2:]
        if is_transposed:
            out_channels = groups * weight.shape[1]

            shape_out = calc_conv_nd_return_shape(
                dims,
                kernel_size,
                stride,
                padding,
                dilation,
                output_padding,
            )

        else:
            out_channels = weight.shape[0]
            if weight.shape[1] != input_tensor.shape[1] / groups:
                raise RuntimeError("Invalid channel dimensions")
            shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
        out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
        mem_fmt = pick_memory_format()
        out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
        return out

    @register_meta(aten._convolution.default)
    def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
                   padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
                   *extra_args):
        out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
        return out

    @register_meta(aten.convolution_backward.default)
    def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
                           padding, dilation, transposed, output_padding, groups, output_mask):
        return new_like(input), new_like(weight), new((bias_sizes))

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
    @register_meta(aten._adaptive_avg_pool2d_backward.default)
    def meta_adaptive_avg_pool2d_backward(
        grad_output: torch.Tensor,
        input: torch.Tensor,
    ):
        return new_like(input)

    # ================================ RNN =============================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
    @register_meta(aten._cudnn_rnn.default)
    def meta_cuda_rnn(
        input,
        weight,
        weight_stride0,
        weight_buf,
        hx,
        cx,
        mode,
        hidden_size,
        proj_size,
        num_layers,
        batch_first,
        dropout,
        train,
        bidirectional,
        batch_sizes,
        dropout_state,
    ):
        is_input_packed = len(batch_sizes) != 0
        if is_input_packed:
            seq_length = len(batch_sizes)
            mini_batch = batch_sizes[0]
            batch_sizes_sum = input.shape[0]
        else:
            seq_length = input.shape[1] if batch_first else input.shape[0]
            mini_batch = input.shape[0] if batch_first else input.shape[1]
            batch_sizes_sum = -1

        num_directions = 2 if bidirectional else 1
        out_size = proj_size if proj_size != 0 else hidden_size
        if is_input_packed:
            out_shape = [batch_sizes_sum, out_size * num_directions]
        else:
            out_shape = ([mini_batch, seq_length, out_size *
                          num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
        output = input.new_empty(out_shape)

        cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
        cy = new(0) if cx is None else cx.new_empty(cell_shape)

        hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

        # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
        reserve_shape = 0 if train else 0
        reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

        return output, hy, cy, reserve, weight_buf

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
    @register_meta(aten._cudnn_rnn_backward.default)
    def meta_cudnn_rnn_backward(input: torch.Tensor,
                                weight: torch.Tensor,
                                weight_stride0: int,
                                hx: torch.Tensor,
                                cx: Optional[torch.Tensor] = None,
                                *args,
                                **kwargs):
        return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
            ())    # (grad_input, grad_weight, grad_hx, grad_cx)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
    # ============================== Activations =======================================
    _unregistered_ewise = [
        aten.relu.default,
        aten.prelu.default,
        aten.hardswish.default,
        aten.hardtanh.default,
        aten.prelu_backward.default,
        aten.hardswish_backward.default,
        aten.hardtanh_backward.default,
    ]

    @register_meta(_unregistered_ewise)
    def meta_unregistered_ewise(input: torch.Tensor, *args):
        return new_like(input)

    # ============================== Normalization =====================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.native_batch_norm.default)
    def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
        n_input = input.size(1)
        return new_like(input), new((n_input)), new((n_input))

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.native_batch_norm_backward.default)
    def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                         save_mean, save_invstd, train, eps, output_mask):
        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.cudnn_batch_norm.default)
    def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
        n_input = input.size(1)
        return new_like(input), new((n_input)), new((n_input)), new(
            (0), dtype=torch.uint8)    # (output, running_mean, running_var, reserve)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    # NB: CuDNN only implements the backward algorithm for batchnorm
    # in training mode (evaluation mode batchnorm has a different algorithm),
    # which is why this doesn't accept a 'training' parameter.
    @register_meta(aten.cudnn_batch_norm_backward.default)
    def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                               save_mean, save_invstd, eps, reserve):
        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
    @register_meta(aten.native_layer_norm.default)
    def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
        bs, n_input = input.size(0), input.size(1)
        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, running_mean, running_var)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
    @register_meta(aten.native_layer_norm_backward.default)
    def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                         grad_input_mask):
        return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)

    # ================================== Misc ==========================================
    # Maybe incorrect
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
    @register_meta(aten.im2col.default)
    def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
        return new_like(input)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.eye.m_out)
    def meta_eye(n: int, m: int, out: torch.Tensor):
        return out

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.roll.default)
    def meta_roll(input: torch.Tensor, shifts, dims):
        return input

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
    @register_meta(aten._local_scalar_dense.default)
    def meta_local_scalar_dense(self: torch.Tensor):
        return 0

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
    @register_meta(aten.where.self)
    def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
        result_type = torch.result_type(self, other)
        return new_like(condition + self + other, dtype=result_type)

    @register_meta(aten.index.Tensor)
    def meta_index_Tensor(self, indices):
        assert indices, "at least one index must be provided"
        # aten::index is the internal advanced indexing implementation
        # checkIndexTensorTypes and expandTensors
        result: List[Optional[torch.Tensor]] = []
        for i, index in enumerate(indices):
            if index is not None:
                assert index.dtype in [torch.long, torch.int8, torch.bool],\
                    "tensors used as indices must be long, byte or bool tensors"
                if index.dtype in [torch.int8, torch.bool]:
                    nonzero = index.nonzero()
                    k = len(result)
                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                    for j in range(index.ndim):
                        assert index.shape[j] == self.shape[
                            k +
                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                        result.append(nonzero.select(1, j))
                else:
                    result.append(index)
            else:
                result.append(index)
        indices = result
        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
        # expand_outplace
        import torch._refs as refs

        indices = list(refs._maybe_broadcast(*indices))
        # add missing null tensors
        while len(indices) < self.ndim:
            indices.append(None)

        # hasContiguousSubspace
        # true if all non-null tensors are adjacent
        # See:
        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
        state = 0
        has_contiguous_subspace = False
        for index in indices:
            if state == 0:
                if index is not None:
                    state = 1
            elif state == 1:
                if index is None:
                    state = 2
            else:
                if index is not None:
                    break
        else:
            has_contiguous_subspace = True

        # transposeToFront
        # This is the logic that causes the newly inserted dimensions to show up
        # at the beginning of the tensor, if they're not contiguous
        if not has_contiguous_subspace:
            dims = []
            transposed_indices = []
            for i, index in enumerate(indices):
                if index is not None:
                    dims.append(i)
                    transposed_indices.append(index)
            for i, index in enumerate(indices):
                if index is None:
                    dims.append(i)
                    transposed_indices.append(index)
            self = self.permute(dims)
            indices = transposed_indices

        # AdvancedIndex::AdvancedIndex
        # Now we can assume the indices have contiguous subspace
        # This is simplified from AdvancedIndex which goes to more effort
        # to put the input and indices in a form so that TensorIterator can
        # take them. If we write a ref for this, probably that logic should
        # get implemented
        before_shape: List[int] = []
        after_shape: List[int] = []
        replacement_shape: List[int] = []
        for dim, index in enumerate(indices):
            if index is None:
                if replacement_shape:
                    after_shape.append(self.shape[dim])
                else:
                    before_shape.append(self.shape[dim])
            else:
                replacement_shape = list(index.shape)
        return self.new_empty(before_shape + replacement_shape + after_shape)

    # ============================== Embedding =========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
    @register_meta(aten.embedding_dense_backward.default)
    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                      scale_grad_by_freq):
        return new((num_weights, grad_output.size(-1)),
                   dtype=grad_output.dtype,
                   device=grad_output.device,
                   layout=grad_output.layout)

    # ============================== Dropout ===========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
    @register_meta(aten.native_dropout.default)
    def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
        # notice that mask is bool
        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
    @register_meta(aten.native_dropout_backward.default)
    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
        return new_like(grad)    # (grad_in)
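
Taken together, the registrations above give these aten ops meta kernels, so output shapes can be inferred on the `meta` device without allocating storage. A minimal sketch of the effect (assuming torch >= 1.12.0, where these overloads exist):

import torch

# Meta tensors carry only shape/dtype metadata, no storage.
x = torch.empty(2, 3, 8, 8, device='meta')
w = torch.empty(4, 3, 3, 3, device='meta')

# conv2d dispatches to the meta kernel, which applies _formula above:
# (8 + 2*0 - 1*(3 - 1) - 1) // 1 + 1 = 6 per spatial dim.
out = torch.nn.functional.conv2d(x, w)
print(out.shape)    # torch.Size([2, 4, 6, 6])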

View File

@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
from packaging import version
aten = torch.ops.aten
@@ -49,40 +50,45 @@ _DistCommMethod = [
"scatter",
]
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # TODO: dive deep here
    # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    _InplaceATen = [
        aten.add_.Tensor,
        aten.add_.Scalar,
        aten.sub_.Tensor,
        aten.sub_.Scalar,
        aten.mul_.Tensor,
        aten.mul_.Scalar,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.pow_.Tensor,
        aten.pow_.Scalar,
    ]

    # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
    _MaybeInplaceATen = [
        aten.diagonal.default,
        aten.expand.default,
        aten.select.int,
        aten.slice.Tensor,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.permute.default,
        aten.unsqueeze.default,
        aten.as_strided.default,
    ]
else:
    _AliasATen = []
    _InplaceATen = []
    _MaybeInplaceATen = []
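
The gate mirrors the one in the meta registrations: `.default`-style overload packets under `torch.ops.aten` cannot be resolved on older torch, so merely building these lists at import time would raise. A hedged sketch of the same pattern in isolation:

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # OpOverload access such as `aten.view.default` resolves here.
    alias_ops = [torch.ops.aten.view.default, torch.ops.aten.t.default]
else:
    # Older torch: fall back to empty tables so importing never fails;
    # callers simply see no alias information.
    alias_ops = []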

View File

@@ -11,6 +11,7 @@ from numbers import Number
from typing import Any, Callable, List, Optional, Union
import torch
from packaging import version
from torch.utils._pytree import tree_map
from .meta_tensor import MetaTensor
@@ -403,134 +404,139 @@ def zero_flop_jit(*args):
return 0
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    flop_mapping = {
        # gemm
        aten.mm.default: matmul_flop_jit,
        aten.matmul.default: matmul_flop_jit,
        aten.addmm.default: addmm_flop_jit,
        aten.bmm.default: bmm_flop_jit,

        # convolution
        aten.convolution.default: conv_flop_jit,
        aten._convolution.default: conv_flop_jit,
        aten.convolution_backward.default: conv_backward_flop_jit,

        # normalization
        aten.native_batch_norm.default: batchnorm_flop_jit,
        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
        aten.native_layer_norm.default: norm_flop_counter(2, 0),
        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),

        # pooling
        aten.avg_pool1d.default: ewise_flop_counter(1, 0),
        aten.avg_pool2d.default: ewise_flop_counter(1, 0),
        aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
        aten.avg_pool3d.default: ewise_flop_counter(1, 0),
        aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
        aten.max_pool1d.default: ewise_flop_counter(1, 0),
        aten.max_pool2d.default: ewise_flop_counter(1, 0),
        aten.max_pool3d.default: ewise_flop_counter(1, 0),
        aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
        aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
        aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
        aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
        aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
        aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
        aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
        aten.embedding.default: ewise_flop_counter(1, 0),
    }

    ewise_flop_aten = [
        # basic op
        aten.add.Tensor,
        aten.add_.Tensor,
        aten.div.Tensor,
        aten.div_.Tensor,
        aten.div.Scalar,
        aten.div_.Scalar,
        aten.mul.Tensor,
        aten.mul.Scalar,
        aten.mul_.Tensor,
        aten.neg.default,
        aten.pow.Tensor_Scalar,
        aten.rsub.Scalar,
        aten.sum.default,
        aten.sum.dim_IntList,
        aten.mean.dim,

        # activation op
        aten.hardswish.default,
        aten.hardswish_.default,
        aten.hardswish_backward.default,
        aten.hardtanh.default,
        aten.hardtanh_.default,
        aten.hardtanh_backward.default,
        aten.hardsigmoid_backward.default,
        aten.hardsigmoid.default,
        aten.gelu.default,
        aten.gelu_backward.default,
        aten.silu.default,
        aten.silu_.default,
        aten.silu_backward.default,
        aten.sigmoid.default,
        aten.sigmoid_backward.default,
        aten._softmax.default,
        aten._softmax_backward_data.default,
        aten.relu_.default,
        aten.relu.default,
        aten.tanh.default,
        aten.tanh_backward.default,
        aten.threshold_backward.default,

        # dropout
        aten.native_dropout.default,
        aten.native_dropout_backward.default,

        # distribution
        aten.bernoulli_.float,

        # where
        aten.where.self,
    ]
    for op in ewise_flop_aten:
        flop_mapping[op] = ewise_flop_counter(1, 0)

    # fix-me: this will be removed in future
    zero_flop_aten = [
        aten.as_strided.default,
        aten.as_strided_.default,
        aten.cat.default,
        aten.clone.default,
        aten.copy_.default,
        aten.detach.default,
        aten.expand.default,
        aten.empty_like.default,
        aten.new_empty.default,
        aten.new_empty_strided.default,
        aten.ones_like.default,
        aten._reshape_alias.default,
        aten.select.int,
        aten.select_backward.default,
        aten.squeeze.dim,
        aten.slice.Tensor,
        aten.slice_backward.default,
        aten.split.Tensor,
        aten.permute.default,
        aten.t.default,
        aten.transpose.int,
        aten._to_copy.default,
        aten.unsqueeze.default,
        aten.unbind.int,
        aten._unsafe_view.default,
        aten.view.default,
        aten.zero_.default,
        aten.zeros_like.default,
    ]
    for op in zero_flop_aten:
        flop_mapping[op] = zero_flop_jit
else:
    flop_mapping = {}
    elementwise_flop_aten = {}
    zero_flop_aten = {}
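
A hedged usage sketch of the table above on torch >= 1.12, assuming these counters follow the fvcore-style `(inputs, outputs)` convention used by similar profilers (the exact call signature is an assumption, not confirmed by this diff):

import torch

aten = torch.ops.aten
a, b = torch.randn(64, 128), torch.randn(128, 256)

# Look up the counter registered for dense matmul and apply it to the
# example tensors; for a (64x128) @ (128x256) gemm this convention
# counts 64 * 128 * 256 = 2097152 multiply-accumulates.
counter = flop_mapping[aten.mm.default]
print(counter([a, b], [a @ b]))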

View File

@@ -1,4 +1,3 @@
from .bias_addition import *
from .node_util import MetaInfo
from .symbolic_profile import symbolic_profile
from .tracer.symbolic_trace import symbolic_trace

View File

@@ -1,4 +1,7 @@
import linecache
import os
import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union
@@ -6,11 +9,74 @@ from typing import Any, Dict, Optional, Union
import torch
import torch.fx
import torch.nn as nn
from torch.fx.graph import PythonCode
try:
from torch.fx.graph import _PyTreeCodeGen
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
from torch.fx.graph_module import _exec_with_source, _forward_from_src
from torch.nn.modules.module import _addindent
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
# Previously, if an error occurred when valid
# symbolically-traced code was run with an invalid input, the
# user would see the source of the error as coming from
# `File "<eval_with_key_N">`, where N is some number. We use
# this function to generate a more informative error message. We
# return the traceback itself, a message explaining that the
# error occurred in a traced Module's generated forward
# function, and five lines of context surrounding the faulty
# line
@staticmethod
def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
# auxiliary variables (for readability)
err_lineno = frame_summary.lineno
assert err_lineno is not None
line = frame_summary.line
assert line is not None
err_line_len = len(line)
all_src_lines = linecache.getlines(frame_summary.filename)
# constituent substrings of the error message
tb_repr = traceback.format_exc()
custom_msg = ("Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:")
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
def __call__(self, obj, *args, **kwargs):
try:
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
topmost_framesummary: traceback.FrameSummary = \
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
else:
raise e
class ColoGraphModule(torch.fx.GraphModule):
"""
ColoGraphModule is an nn.Module generated from an fx.Graph.
@@ -65,7 +131,7 @@ class ColoGraphModule(torch.fx.GraphModule):
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')

View File

@@ -20,7 +20,7 @@ def union(a, b):
return {**a, **b}
def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Compute the size of a tensor or a collection of tensors in bytes.
Args:
@@ -195,8 +195,8 @@ class MetaInfo:
s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
if self.output_size:
s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
# if self.total_size:
# s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
if self.temp_size:
s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
if self.backward_size:

View File

@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
with self.global_hook:
r = getattr(self, n.op)(n.target, args, kwargs)
def unwrap_fn(elem):
def _convert_meta(t: torch.Tensor):
if t.device == 'meta':
return t
else:
return t.to('meta')
if isinstance(elem, MetaTensor):
return _convert_meta(elem._tensor)
elif isinstance(elem, torch.Tensor):
return _convert_meta(elem)
else:
return elem
# unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
n_info = MetaInfo(n)
n_info.outputs = _normalize_tuple(r)
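
The replacement above makes unwrapping device-aware: every tensor handed back by the interpreter is forced onto the `meta` device before metadata is recorded. A self-contained sketch of that pattern (plain tensors only; the `MetaTensor` branch is elided):

import torch
from torch.utils._pytree import tree_map

def to_meta(elem):
    # Move tensors to the 'meta' device; pass everything else through.
    if isinstance(elem, torch.Tensor):
        return elem if elem.device.type == 'meta' else elem.to('meta')
    return elem

args = {'x': torch.randn(2, 2), 'scale': 3}
meta_args = tree_map(to_meta, args)
print(meta_args['x'].device, meta_args['scale'])    # meta 3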

View File

@@ -0,0 +1,2 @@
from .bias_addition import *
from .custom_leaf_module import *

View File

@@ -4,11 +4,10 @@ graph construction to deal with the compatibility between bias-addition and all-
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _single, _triple
from .tracer import register_tracer_impl
__all__ = []

View File

@@ -0,0 +1,29 @@
import torch
from .tracer import register_leaf_module, register_leaf_module_impl
try:
import apex
register_leaf_module(apex.normalization.FusedLayerNorm)
register_leaf_module(apex.normalization.FusedRMSNorm)
register_leaf_module(apex.normalization.MixedFusedLayerNorm)
register_leaf_module(apex.normalization.MixedFusedRMSNorm)
@register_leaf_module_impl(apex.normalization.FusedLayerNorm)
@register_leaf_module_impl(apex.normalization.FusedRMSNorm)
@register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
@register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
def torch_nn_normalize(self, input: torch.Tensor):
# check shape
if isinstance(self, torch.nn.BatchNorm1d):
assert input.dim() in [2, 3]
elif isinstance(self, torch.nn.BatchNorm2d):
assert input.dim() == 4
elif isinstance(self, torch.nn.BatchNorm3d):
assert input.dim() == 5
# normalization maintain the same shape as the input
return input.clone()
except (ImportError, AttributeError):
pass
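
A hedged sketch of registering a user-defined leaf module the same way, assuming the public import path mirrors the relative import above (the module `MyFusedNorm` is hypothetical):

import torch
from colossalai._analyzer.fx.tracer import register_leaf_module, register_leaf_module_impl

class MyFusedNorm(torch.nn.Module):    # hypothetical user module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / (x.norm() + 1e-6)

register_leaf_module(MyFusedNorm)

@register_leaf_module_impl(MyFusedNorm)
def my_fused_norm_impl(self, input: torch.Tensor):
    # Like the apex impls above: normalization preserves the input shape.
    return input.clone()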

View File

@@ -0,0 +1,112 @@
import operator
from typing import Any, Callable, Dict, Optional, Set, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
Target = Union[Callable[..., Any], str]
class ColoProxy(Proxy):
_func_dispatch: Dict[Target, Callable[..., Any]] = {}
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@property
def meta_data(self):
return self._meta_data
@meta_data.setter
def meta_data(self, args):
wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
self._meta_data = tree_map(wrap_fn, args)
@classmethod
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
else:
proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if proxy.meta_data is None:
proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
return proxy
@classmethod
def from_torch_proxy(cls, proxy: Proxy):
return cls(proxy.node, proxy.tracer)
def __repr__(self):
return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
def __len__(self):
return len(self.meta_data)
def __int__(self):
return int(self.meta_data)
def __index__(self):
try:
return int(self.meta_data)
except:
return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
def __float__(self):
return float(self.meta_data)
def __bool__(self):
return self.meta_data
def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
def __contains__(self, key):
if self.node.op == "placeholder":
# this is used to handle like
# if x in kwargs
# we don't handle this case for now
return False
return super().__contains__(key)
def __isinstancecheck__(self, type):
return isinstance(self.meta_data, type)
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._meta_data = data
self._node: Optional[Node] = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"

View File

@@ -0,0 +1,157 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
from torch.fx import Tracer
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
try:
from ..codegen import ActivationCheckpointCodeGen
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
from ..graph_module import ColoGraphModule
from .tracer import ColoTracer
def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
def _current_device(module: torch.nn.Module):
try:
return next(module.parameters()).device
except:
return _default_device()
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt: bool = False,
bias_addition_split: bool = False,
) -> ColoGraphModule:
"""
Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
attached to the ``Node``s.
Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
(https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
This tracer is able to trace basic control flow and for loops.
It will split the bias addition into two parts if ``bias_addition_split`` is set to ``True``.
(See ./bias_addition.py for more details).
Examples:
1. Tracing a ``torch.nn.Module`` with control flow.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
if x.size(0) > 1:
x = x.sum(dim=0)
return self.linear(x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
# traced code like:
# def forward(self, x):
# linear_1 = self.linear(x)
# return linear_1
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
# traced code like:
# def forward(self, x):
# sum = x.sum(dim=0); x = None
# linear = self.linear(sum); sum = None
# return linear
2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
def custom_forward(x):
return self.linear(x)
return torch.utils.checkpoint.checkpoint(custom_forward, x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
# traced code like:
# def checkpoint_0(self, x):
# linear = self.linear(x); x = None
# return linear
#
# def forward(self, x):
# linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
# return linear
3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2, bias=True)
def forward(self, x):
return self.linear(x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
# traced code like:
# def forward(self, x):
# linear_bias = self.linear.bias
# linear_weight = self.linear.weight
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# return add
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
Defaults to {}.
meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
for tracing control flow. Defaults to {}.
trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
Defaults to False.
bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
Returns:
ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
Remarks:
This part of ``symbolic_trace()`` is maintained by the Colossal-AI team. If you encounter
any unexpected error during tracing, feel free to raise an issue on the Colossal-AI GitHub
repo. We welcome any feedback and contributions to enhance the extensibility of
Colossal-AI.
"""
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
bias_addition_split=bias_addition_split).trace(root.to(device),
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
if trace_act_ckpt and SUPPORT_ACTIVATION:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
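
A short end-to-end sketch mirroring the docstring examples (the placeholder name `input` comes from `nn.Sequential.forward`; the import path assumes the package `__init__` shown earlier in this diff):

import torch
from colossalai._analyzer.fx import symbolic_trace

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
gm = symbolic_trace(model, meta_args={'input': torch.randn(2, 4)})
print(gm.code)    # the generated forward, traced on meta tensors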

View File

@@ -1,28 +1,19 @@
import functools
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
from ..node_util import MetaInfo
from .proxy import ColoProxy
def _truncate_suffix(s: str):
@@ -32,17 +23,6 @@ def _truncate_suffix(s: str):
return re.sub(r'_\d+$', '', s)
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
def wrapper(impl):
@@ -70,149 +50,6 @@ def register_non_leaf_module(module: nn.Module):
ColoTracer._custom_non_leaf_module.add(module)
class ColoProxy(Proxy):
_func_dispatch: Dict[Target, Callable[..., Any]] = {}
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs)
self._meta_data = data
@property
def meta_data(self):
return self._meta_data
@meta_data.setter
def meta_data(self, args):
wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
self._meta_data = tree_map(wrap_fn, args)
@classmethod
def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if orig_method in cls._func_dispatch:
impl = cls._func_dispatch.pop(orig_method) # avoid recursion
proxy = impl(*args, **kwargs)
cls._func_dispatch[orig_method] = impl
return proxy
else:
proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if proxy.meta_data is None:
proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
return proxy
@classmethod
def from_torch_proxy(cls, proxy: Proxy):
return cls(proxy.node, proxy.tracer)
def __repr__(self):
return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
def __len__(self):
return len(self.meta_data)
def __int__(self):
return int(self.meta_data)
def __index__(self):
try:
return int(self.meta_data)
        except Exception:
return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
def __float__(self):
return float(self.meta_data)
def __bool__(self):
        return bool(self.meta_data)
def __getattr__(self, k):
return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data
return proxy
def __contains__(self, key):
if self.node.op == "placeholder":
            # handles patterns like `if x in kwargs`;
            # membership tests on placeholders are not supported yet
return False
return super().__contains__(key)
    def __instancecheck__(self, type):
return isinstance(self.meta_data, type)
def size(self, dim=None):
        if self._meta_data is not None:
            return self._meta_data.size(*[dim] if dim is not None else [])
        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim is not None else (self,), {})
def dim(self):
if self._meta_data is not None:
return self._meta_data.dim()
return self.tracer.create_proxy('call_method', 'dim', (self,), {})
@property
def shape(self):
if self._meta_data is not None:
return self._meta_data.shape
return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})
@property
def ndim(self):
if self._meta_data is not None:
return self._meta_data.ndim
return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})
@property
def device(self):
if self._meta_data is not None:
return self._meta_data.device
return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
@property
def dtype(self):
if self._meta_data is not None:
return self._meta_data.dtype
return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
def to(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
def cpu(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
def cuda(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._meta_data = data
self._node: Optional[Node] = None
@property
def node(self):
        # the node for attributes is added lazily, since most attribute accesses
        # end in a method call and never need a standalone getattr node
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})"
class ColoTracer(Tracer):
_custom_leaf_module: Set[Type[nn.Module]] = set()
_custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
@ -249,7 +86,6 @@ class ColoTracer(Tracer):
# we will enter the module and split the bias-addition ops
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
return False
# user can specify which modules are leaf modules and which are not
return (type(m) not in self._custom_non_leaf_module
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
@ -306,9 +142,13 @@ class ColoTracer(Tracer):
mod = self.root.get_submodule(target)
self.disable_module_getattr = True
try:
proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
mod.forward)(*tree_map(unwrap_fn, args),
**tree_map(unwrap_fn, kwargs))
args = tree_map(unwrap_fn, args)
kwargs = tree_map(unwrap_fn, kwargs)
if type(mod) in self._custom_leaf_module:
target = self._custom_leaf_module_impl[type(mod)]
proxy.meta_data = target(mod, *args, **kwargs)
else:
proxy.meta_data = mod.forward(*args, **kwargs)
finally:
self.disable_module_getattr = False
return proxy
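Note the calling-convention change above: a registered leaf-module implementation now receives the module instance as its first argument, instead of being bound at lookup time. A hedged sketch of a conforming implementation (``linear_meta_impl`` is illustrative):

import torch

def linear_meta_impl(mod: torch.nn.Linear, x: torch.Tensor):
    # shape inference only; no real computation happens on meta tensors
    return torch.empty(*x.shape[:-1], mod.out_features, device='meta')

ColoTracer._custom_leaf_module.add(torch.nn.Linear)
ColoTracer._custom_leaf_module_impl[torch.nn.Linear] = linear_meta_impl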
@ -320,15 +160,21 @@ class ColoTracer(Tracer):
def trace(self,
root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = {},
meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
if meta_args is None:
meta_args = {}
if concrete_args is None:
concrete_args = {}
# check concrete and meta args have valid names
sig = inspect.signature(root.forward)
sig_names = set(sig.parameters.keys())
meta_arg_names = set(meta_args.keys())
concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names
# update concrete args with default values
for k, v in sig.parameters.items():
if k in sig_names - meta_arg_names and \
@ -352,6 +198,34 @@ class ColoTracer(Tracer):
self.graph = super().trace(root, concrete_args=concrete_args)
self.mod_dir = ''
self.graph.lint()
for node in self.graph.nodes:
if node.op == "placeholder":
# Removing default values for inputs as the forward pass will fail with them.
if node.target in non_concrete_arg_names:
node.args = ()
                    # Without this, torch.jit.script fails because the input type is Optional[torch.Tensor].
                    # It cannot infer the attributes and methods the input should have, and fails.
node.type = torch.Tensor
# It is a concrete arg so it is not used and should be removed.
else:
if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
# Newer versions of torch.fx emit an assert statement
# for concrete arguments; delete those before we delete
# the concrete arg.
to_delete = []
for user in node.users:
if user.target == torch.fx._symbolic_trace._assert_is_none:
to_delete.append(user)
for user in to_delete:
self.graph.erase_node(user)
self.graph.erase_node(node)
            # TODO: solve GraphModule creation.
            # Without this, the return type annotation "Tuple" causes code execution to fail.
if node.op == "output":
node.type = None
return self.graph
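The placeholder post-processing above mirrors the HuggingFace tracer: non-concrete inputs lose their default values and are annotated as ``torch.Tensor``, while concrete inputs (together with any ``_assert_is_none`` guards emitted by newer torch.fx) are erased from the graph. A hedged sketch of the observable effect:

import torch
from colossalai._analyzer.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x, flag=None):
        return x + 1 if flag is None else x - 1

gm = symbolic_trace(M(), meta_args={'x': torch.randn(2)}, concrete_args={'flag': None})
# 'flag' has been erased; only 'x' survives as a placeholder
print([n.name for n in gm.graph.nodes if n.op == 'placeholder'])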
@contextmanager
@ -454,7 +328,7 @@ class ColoTracer(Tracer):
if node.op == "output":
node.type = None
self.graph.lint()
def getattr(self, attr, attr_val, parameter_proxy_cache):
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
@ -487,134 +361,3 @@ class ColoTracer(Tracer):
return maybe_parameter_proxy
return attr_val
def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = {},
meta_args: Optional[Dict[str, Any]] = {},
trace_act_ckpt: bool = False,
bias_addition_split: bool = False,
) -> ColoGraphModule:
"""
Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
attached to the ``Node``s.
    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of each module
    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
    This tracer can trace basic control flow and for loops.
    It will split the bias addition into two parts if ``bias_addition_split`` is set to ``True``
    (see ./bias_addition.py for more details).
Examples:
1. Tracing a ``torch.nn.Module`` with control flow.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
if x.size(0) > 1:
x = x.sum(dim=0)
return self.linear(x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
# traced code like:
# def forward(self, x):
# linear_1 = self.linear(x)
# return linear_1
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
# traced code like:
# def forward(self, x):
# sum = x.sum(dim=0); x = None
# linear = self.linear(sum); sum = None
# return linear
2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
def custom_forward(x):
return self.linear(x)
return torch.utils.checkpoint.checkpoint(custom_forward, x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
# traced code like:
# def checkpoint_0(self, x):
# linear = self.linear(x); x = None
# return linear
#
# def forward(self, x):
# linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
# return linear
3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
.. code-block:: python
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2, bias=True)
def forward(self, x):
return self.linear(x)
traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
# traced code like:
# def forward(self, x):
# linear_bias = self.linear.bias
# linear_weight = self.linear.weight
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# return add
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
Defaults to {}.
meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
for tracing control flow. Defaults to {}.
trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
Defaults to False.
bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
Returns:
ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
Remarks:
This part of ``symbolic_trace()`` is maintained by the Colossal-AI team. If you encounter
any unexpected errors during tracing, feel free to raise an issue on the Colossal-AI GitHub
repo. We welcome any feedback and contributions that enhance the extensibility of
Colossal-AI.
"""
if meta_args:
device, orig_device = _default_device(), _current_device(root)
wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
bias_addition_split=bias_addition_split).trace(root.to(device),
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
if trace_act_ckpt:
graph.set_codegen(ActivationCheckpointCodeGen())
root.to(orig_device)
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)

View File

@ -1,5 +1,4 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
from .registry import model_zoo
__all__ = ['model_zoo']

View File

@ -17,6 +17,14 @@ def data_gen():
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def seq_classification_data_gen():
    # batch size must be 1 if no padding token is defined (GPT2ForSequenceClassification requirement)
input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
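An alternative to shrinking the batch (a hedged sketch, not what the test kit does) is to give the config a padding token so sequence classification can accept larger batches:

config_with_pad = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4, pad_token_id=50256)
model = transformers.GPT2ForSequenceClassification(config_with_pad)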
output_transform_fn = lambda x: x
config = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4)
@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
data_gen_fn=data_gen,
data_gen_fn=seq_classification_data_gen,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
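For context, a sketch of how these registrations are consumed (the four-tuple unpacking matches the tracer test loops later in this commit):

from tests.kit.model_zoo import model_zoo

sub_registry = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_registry.items():
    model = model_fn()      # build a fresh model instance
    data = data_gen_fn()    # build matching dummy inputs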

View File

@ -1,5 +1,6 @@
import pytest
import torch
from packaging import version
from torch.utils.checkpoint import checkpoint
try:
@ -73,7 +74,7 @@ class AddmmModel(torch.nn.Module):
return x
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])

View File

@ -3,7 +3,8 @@ from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
from colossalai.fx import symbolic_trace
# from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):

View File

@ -1,4 +1,7 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@ -6,6 +9,7 @@ BATCH_SIZE = 2
SEQ_LENGTH = 16
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_albert():
sub_registry = model_zoo.get_sub_registry('transformers_albert')

View File

@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_bert():
sub_registry = model_zoo.get_sub_registry('transformers_bert')

View File

@ -1,16 +1,24 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
# TODO: remove this skip once we handle the latest gpt model
@pytest.mark.skip
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_gpt():
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
model = model_fn()
# TODO: support the following models
# 1. GPT2DoubleHeadsModel
# as they are not supported, let's skip them
if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
continue
trace_model_and_compare_output(model, data_gen_fn)

View File

@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_opt():
sub_registry = model_zoo.get_sub_registry('transformers_opt')

View File

@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_t5():
sub_registry = model_zoo.get_sub_registry('transformers_t5')

View File

@ -1,8 +1,8 @@
import pytest
import timm.models as tm
import torch
from packaging import version
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
@ -42,6 +42,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_timm_models():
torch.backends.cudnn.deterministic = True
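Determinism is pinned so the eager and traced outputs can be compared exactly. A hedged sketch of the usual companion settings (the extra lines are assumptions, not part of this diff):

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False    # assumed companion setting
torch.manual_seed(0)                      # assumed; fixes random init across runs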

View File

@ -1,20 +1,18 @@
import re
import pytest
import torch
from packaging import version
from torchaudio_utils import trace_and_compare
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_torchaudio_models():
torch.backends.cudnn.deterministic = True
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
# FIXME(ver217): temporarily skip these models
        if re.search('(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
continue
model = model_fn()
trace_and_compare(model,
data_gen_fn,

View File

@ -1,6 +1,6 @@
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):

View File

@ -1,7 +1,7 @@
import pytest
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
BATCH = 2

View File

@ -1,7 +1,7 @@
import pytest
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
BATCH = 2

View File

@ -1,6 +1,6 @@
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo