mirror of https://github.com/hpcaitech/ColossalAI
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop
* pass t5 trace and meta_prop
* [FX] refactor experimental tracer and adapt it with hf models
* pass all mainstream model zoo
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* skip tests
* fix CI
* using packaging version
* polish

pull/3197/head
parent b429529365
commit f57d34958b
@@ -6,11 +6,15 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
from packaging import version
from torch.utils._pytree import tree_map

aten = torch.ops.aten

try:
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
except AttributeError:
    meta_lib = None

meta_table = {}
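Note: `torch.library.Library` only exists on newer PyTorch releases, which is why the handle is created inside a try/except and left as `None` otherwise. A minimal sketch of how a `register_meta`-style decorator can guard on it (hypothetical helper for context, not this file's exact definition):

def register_meta_sketch(op):
    # hypothetical mirror of register_meta, shown only to illustrate the guard
    def wrapper(fn):
        if meta_lib is not None:
            # dispatch `op` to `fn` whenever its inputs live on the meta device
            meta_lib.impl(op, fn)
        meta_table[op] = fn
        return fn
    return wrapper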
@@ -50,432 +54,411 @@ def register_meta(op, register_dispatcher=True):
    return wrapper


if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # ============================== Convolutions ======================================
    # https://github.com/pytorch/pytorch/pull/79834
    @register_meta(aten.convolution.default)
    def meta_conv(
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        is_transposed: bool,
        output_padding: List[int],
        groups: int,
    ):

        def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
            """
            Formula to apply to calculate the length of some dimension of the output
            See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            Args:
                ln: length of the dimension
                p: padding in that dim
                d: dilation in that dim
                k: kernel size in that dim
                s: stride in that dim
            Returns:
                The output length
            """
            return (ln + 2 * p - d * (k - 1) - 1) // s + 1

        def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
            """
            Formula to apply to calculate the length of some dimension of the output
            if transposed convolution is used.
            See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
            Args:
                ln: length of the dimension
                p: padding in that dim
                d: dilation in that dim
                k: kernel size in that dim
                s: stride in that dim
                op: output padding in that dim
            Returns:
                The output length
            """
            return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

        def calc_conv_nd_return_shape(
            dims: torch.Size,
            kernel_size: torch.Size,
            stride: Union[List[int], int],
            padding: Union[List[int], int],
            dilation: Union[List[int], int],
            output_padding: Optional[Union[List[int], int]] = None,
        ):
            ret_shape = []
            if isinstance(stride, int):
                stride = [stride] * len(dims)
            elif len(stride) == 1:
                stride = [stride[0]] * len(dims)

            if isinstance(padding, int):
                padding = [padding] * len(dims)
            elif len(padding) == 1:
                padding = [padding[0]] * len(dims)

            if isinstance(dilation, int):
                dilation = [dilation] * len(dims)
            elif len(dilation) == 1:
                dilation = [dilation[0]] * len(dims)

            output_padding_list: Optional[List[int]] = None
            if output_padding:
                if isinstance(output_padding, int):
                    output_padding_list = [output_padding] * len(dims)
                elif len(output_padding) == 1:
                    output_padding_list = [output_padding[0]] * len(dims)
                else:
                    output_padding_list = output_padding

            for i in range(len(dims)):
                # If output_padding is present, we are dealing with a transposed convolution
                if output_padding_list:
                    ret_shape.append(
                        _formula_transposed(
                            dims[i],
                            padding[i],
                            dilation[i],
                            kernel_size[i],
                            stride[i],
                            output_padding_list[i],
                        ))
                else:
                    ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
            return ret_shape

        def pick_memory_format():
            if input_tensor.is_contiguous(memory_format=torch.channels_last):
                return torch.channels_last
            elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
                return torch.contiguous_format
            elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
                return torch.preserve_format

        kernel_size = weight.shape[2:]
        dims = input_tensor.shape[2:]
        if is_transposed:
            out_channels = groups * weight.shape[1]

            shape_out = calc_conv_nd_return_shape(
                dims,
                kernel_size,
                stride,
                padding,
                dilation,
                output_padding,
            )

        else:
            out_channels = weight.shape[0]
            if weight.shape[1] != input_tensor.shape[1] / groups:
                raise RuntimeError("Invalid channel dimensions")
            shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
        out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
        mem_fmt = pick_memory_format()
        out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
        return out

    @register_meta(aten._convolution.default)
    def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
                   padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
                   *extra_args):
        out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
        return out

    @register_meta(aten.convolution_backward.default)
    def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
                           padding, dilation, transposed, output_padding, groups, output_mask):
        return new_like(input), new_like(weight), new((bias_sizes))

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
    @register_meta(aten._adaptive_avg_pool2d_backward.default)
    def meta_adaptive_avg_pool2d_backward(
        grad_output: torch.Tensor,
        input: torch.Tensor,
    ):
        return new_like(input)

    # ================================ RNN =============================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
    @register_meta(aten._cudnn_rnn.default)
    def meta_cuda_rnn(
        input,
        weight,
        weight_stride0,
        weight_buf,
        hx,
        cx,
        mode,
        hidden_size,
        proj_size,
        num_layers,
        batch_first,
        dropout,
        train,
        bidirectional,
        batch_sizes,
        dropout_state,
    ):
        is_input_packed = len(batch_sizes) != 0
        if is_input_packed:
            seq_length = len(batch_sizes)
            mini_batch = batch_sizes[0]
            batch_sizes_sum = input.shape[0]
        else:
            seq_length = input.shape[1] if batch_first else input.shape[0]
            mini_batch = input.shape[0] if batch_first else input.shape[1]
            batch_sizes_sum = -1

        num_directions = 2 if bidirectional else 1
        out_size = proj_size if proj_size != 0 else hidden_size
        if is_input_packed:
            out_shape = [batch_sizes_sum, out_size * num_directions]
        else:
            out_shape = ([mini_batch, seq_length, out_size *
                          num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
        output = input.new_empty(out_shape)

        cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
        cy = new(0) if cx is None else cx.new_empty(cell_shape)

        hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

        # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
        reserve_shape = 0 if train else 0
        reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

        return output, hy, cy, reserve, weight_buf

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
    @register_meta(aten._cudnn_rnn_backward.default)
    def meta_cudnn_rnn_backward(input: torch.Tensor,
                                weight: torch.Tensor,
                                weight_stride0: int,
                                hx: torch.Tensor,
                                cx: Optional[torch.Tensor] = None,
                                *args,
                                **kwargs):
        return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
            ())    # (grad_input, grad_weight, grad_hx, grad_cx)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
    # ============================== Activations =======================================
    _unregistered_ewise = [
        aten.relu.default,
        aten.prelu.default,
        aten.hardswish.default,
        aten.hardtanh.default,
        aten.prelu_backward.default,
        aten.hardswish_backward.default,
        aten.hardtanh_backward.default,
    ]

    @register_meta(_unregistered_ewise)
    def meta_unregistered_ewise(input: torch.Tensor, *args):
        return new_like(input)

    # ============================== Normalization =====================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.native_batch_norm.default)
    def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
        n_input = input.size(1)
        return new_like(input), new((n_input)), new((n_input))

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.native_batch_norm_backward.default)
    def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                         save_mean, save_invstd, train, eps, output_mask):
        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.cudnn_batch_norm.default)
    def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
        n_input = input.size(1)
        return new_like(input), new((n_input)), new((n_input)), new(
            (0), dtype=torch.uint8)    # (output, running_mean, running_var, reserve)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    # NB: CuDNN only implements the backward algorithm for batchnorm
    # in training mode (evaluation mode batchnorm has a different algorithm),
    # which is why this doesn't accept a 'training' parameter.
    @register_meta(aten.cudnn_batch_norm_backward.default)
    def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                               save_mean, save_invstd, eps, reserve):
        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
    @register_meta(aten.native_layer_norm.default)
    def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
        bs, n_input = input.size(0), input.size(1)
        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, running_mean, running_var)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
    @register_meta(aten.native_layer_norm_backward.default)
    def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                         grad_input_mask):
        return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)

    # ================================== Misc ==========================================
    # Maybe incorrect
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
    @register_meta(aten.im2col.default)
    def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
        return new_like(input)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.eye.m_out)
    def meta_eye(n: int, m: int, out: torch.Tensor):
        return out

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.roll.default)
    def meta_roll(input: torch.Tensor, shifts, dims):
        return input

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
    @register_meta(aten._local_scalar_dense.default)
    def meta_local_scalar_dense(self: torch.Tensor):
        return 0

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
    @register_meta(aten.where.self)
    def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
        result_type = torch.result_type(self, other)
        return new_like(condition + self + other, dtype=result_type)

    @register_meta(aten.index.Tensor)
    def meta_index_Tensor(self, indices):
        assert indices, "at least one index must be provided"
        # aten::index is the internal advanced indexing implementation
        # checkIndexTensorTypes and expandTensors
        result: List[Optional[torch.Tensor]] = []
        for i, index in enumerate(indices):
            if index is not None:
                assert index.dtype in [torch.long, torch.int8, torch.bool],\
                    "tensors used as indices must be long, byte or bool tensors"
                if index.dtype in [torch.int8, torch.bool]:
                    nonzero = index.nonzero()
                    k = len(result)
                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                    for j in range(index.ndim):
                        assert index.shape[j] == self.shape[
                            k +
                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                        result.append(nonzero.select(1, j))
                else:
                    result.append(index)
            else:
                result.append(index)
        indices = result
        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
        # expand_outplace
        import torch._refs as refs

        indices = list(refs._maybe_broadcast(*indices))
        # add missing null tensors
        while len(indices) < self.ndim:
            indices.append(None)

        # hasContiguousSubspace
        # true if all non-null tensors are adjacent
        # See:
        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
        state = 0
        has_contiguous_subspace = False
        for index in indices:
            if state == 0:
                if index is not None:
                    state = 1
            elif state == 1:
                if index is None:
                    state = 2
            else:
                if index is not None:
                    break
        else:
            has_contiguous_subspace = True

        # transposeToFront
        # This is the logic that causes the newly inserted dimensions to show up
        # at the beginning of the tensor, if they're not contiguous
        if not has_contiguous_subspace:
            dims = []
            transposed_indices = []
            for i, index in enumerate(indices):
                if index is not None:
                    dims.append(i)
                    transposed_indices.append(index)
            for i, index in enumerate(indices):
                if index is None:
                    dims.append(i)
                    transposed_indices.append(index)
            self = self.permute(dims)
            indices = transposed_indices

        # AdvancedIndex::AdvancedIndex
        # Now we can assume the indices have contiguous subspace
        # This is simplified from AdvancedIndex which goes to more effort
        # to put the input and indices in a form so that TensorIterator can
        # take them. If we write a ref for this, probably that logic should
        # get implemented
        before_shape: List[int] = []
        after_shape: List[int] = []
        replacement_shape: List[int] = []
        for dim, index in enumerate(indices):
            if index is None:
                if replacement_shape:
                    after_shape.append(self.shape[dim])
                else:
                    before_shape.append(self.shape[dim])
            else:
                replacement_shape = list(index.shape)
        return self.new_empty(before_shape + replacement_shape + after_shape)

    # ============================== Embedding =========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
    @register_meta(aten.embedding_dense_backward.default)
    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                      scale_grad_by_freq):
        return new((num_weights, grad_output.size(-1)),
                   dtype=grad_output.dtype,
                   device=grad_output.device,
                   layout=grad_output.layout)

    # ============================== Dropout ===========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
    @register_meta(aten.native_dropout.default)
    def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
        # notice that mask is bool
        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
    @register_meta(aten.native_dropout_backward.default)
    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
        return new_like(grad)    # (grad_in)
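As a sanity check on the shape logic above (illustrative, not part of the commit): with these meta kernels registered, a convolution can be shape-checked on the meta device without allocating real memory, and the result agrees with `_formula`.

import torch

# assumes torch >= 1.12 so the registrations above are active
x = torch.empty(1, 3, 32, 32, device='meta')
w = torch.empty(8, 3, 3, 3, device='meta')
# aten.convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
out = torch.ops.aten.convolution.default(x, w, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
# _formula: (32 + 2*1 - 1*(3 - 1) - 1) // 1 + 1 == 32, so spatial size is preserved
assert out.shape == (1, 8, 32, 32)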
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
from packaging import version

aten = torch.ops.aten
@@ -49,40 +50,45 @@ _DistCommMethod = [
    "scatter",
]

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # TODO: dive deep here
    # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    _InplaceATen = [
        aten.add_.Tensor,
        aten.add_.Scalar,
        aten.sub_.Tensor,
        aten.sub_.Scalar,
        aten.mul_.Tensor,
        aten.mul_.Scalar,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.pow_.Tensor,
        aten.pow_.Scalar,
    ]

    # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
    _MaybeInplaceATen = [
        aten.diagonal.default,
        aten.expand.default,
        aten.select.int,
        aten.slice.Tensor,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.permute.default,
        aten.unsqueeze.default,
        aten.as_strided.default,
    ]
else:
    _AliasATen = []
    _InplaceATen = []
    _MaybeInplaceATen = []
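A hedged sketch of how these tables are typically consulted downstream (the consumer is not shown in this diff; `shares_storage_with_input` is a hypothetical name):

def shares_storage_with_input(func) -> bool:
    # alias ops return views; inplace ops mutate their input in place;
    # "maybe inplace" ops go through as_strided()/slice() and may alias
    return func in _AliasATen or func in _InplaceATen or func in _MaybeInplaceATen

# e.g. on torch >= 1.12: shares_storage_with_input(aten.view.default) -> True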
@@ -11,6 +11,7 @@ from numbers import Number
from typing import Any, Callable, List, Optional, Union

import torch
from packaging import version
from torch.utils._pytree import tree_map

from .meta_tensor import MetaTensor
@@ -403,134 +404,139 @@ def zero_flop_jit(*args):
    return 0


if version.parse(torch.__version__) >= version.parse('1.12.0'):
    flop_mapping = {
        # gemm
        aten.mm.default: matmul_flop_jit,
        aten.matmul.default: matmul_flop_jit,
        aten.addmm.default: addmm_flop_jit,
        aten.bmm.default: bmm_flop_jit,

        # convolution
        aten.convolution.default: conv_flop_jit,
        aten._convolution.default: conv_flop_jit,
        aten.convolution_backward.default: conv_backward_flop_jit,

        # normalization
        aten.native_batch_norm.default: batchnorm_flop_jit,
        aten.native_batch_norm_backward.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm.default: batchnorm_flop_jit,
        aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
        aten.native_layer_norm.default: norm_flop_counter(2, 0),
        aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),

        # pooling
        aten.avg_pool1d.default: ewise_flop_counter(1, 0),
        aten.avg_pool2d.default: ewise_flop_counter(1, 0),
        aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
        aten.avg_pool3d.default: ewise_flop_counter(1, 0),
        aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
        aten.max_pool1d.default: ewise_flop_counter(1, 0),
        aten.max_pool2d.default: ewise_flop_counter(1, 0),
        aten.max_pool3d.default: ewise_flop_counter(1, 0),
        aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
        aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
        aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
        aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
        aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
        aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
        aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
        aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
        aten.embedding.default: ewise_flop_counter(1, 0),
    }

    ewise_flop_aten = [
        # basic op
        aten.add.Tensor,
        aten.add_.Tensor,
        aten.div.Tensor,
        aten.div_.Tensor,
        aten.div.Scalar,
        aten.div_.Scalar,
        aten.mul.Tensor,
        aten.mul.Scalar,
        aten.mul_.Tensor,
        aten.neg.default,
        aten.pow.Tensor_Scalar,
        aten.rsub.Scalar,
        aten.sum.default,
        aten.sum.dim_IntList,
        aten.mean.dim,

        # activation op
        aten.hardswish.default,
        aten.hardswish_.default,
        aten.hardswish_backward.default,
        aten.hardtanh.default,
        aten.hardtanh_.default,
        aten.hardtanh_backward.default,
        aten.hardsigmoid_backward.default,
        aten.hardsigmoid.default,
        aten.gelu.default,
        aten.gelu_backward.default,
        aten.silu.default,
        aten.silu_.default,
        aten.silu_backward.default,
        aten.sigmoid.default,
        aten.sigmoid_backward.default,
        aten._softmax.default,
        aten._softmax_backward_data.default,
        aten.relu_.default,
        aten.relu.default,
        aten.tanh.default,
        aten.tanh_backward.default,
        aten.threshold_backward.default,

        # dropout
        aten.native_dropout.default,
        aten.native_dropout_backward.default,

        # distribution
        aten.bernoulli_.float,

        # where
        aten.where.self,
    ]
    for op in ewise_flop_aten:
        flop_mapping[op] = ewise_flop_counter(1, 0)

    # fix-me: this will be removed in future
    zero_flop_aten = [
        aten.as_strided.default,
        aten.as_strided_.default,
        aten.cat.default,
        aten.clone.default,
        aten.copy_.default,
        aten.detach.default,
        aten.expand.default,
        aten.empty_like.default,
        aten.new_empty.default,
        aten.new_empty_strided.default,
        aten.ones_like.default,
        aten._reshape_alias.default,
        aten.select.int,
        aten.select_backward.default,
        aten.squeeze.dim,
        aten.slice.Tensor,
        aten.slice_backward.default,
        aten.split.Tensor,
        aten.permute.default,
        aten.t.default,
        aten.transpose.int,
        aten._to_copy.default,
        aten.unsqueeze.default,
        aten.unbind.int,
        aten._unsafe_view.default,
        aten.view.default,
        aten.zero_.default,
        aten.zeros_like.default,
    ]

    for op in zero_flop_aten:
        flop_mapping[op] = zero_flop_jit
else:
    flop_mapping = {}
    elementwise_flop_aten = {}
    zero_flop_aten = {}
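For intuition about the counters (an illustrative aside, not commit content): an [m, k] x [k, n] matmul performs m*k*n multiply-accumulates, which fvcore-style counters report directly; doubling gives raw FLOPs.

def matmul_macs(m: int, k: int, n: int) -> int:
    # one multiply-accumulate per (i, j, l) index triple
    return m * k * n

assert matmul_macs(64, 128, 256) == 2_097_152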
@@ -1,4 +1,3 @@
from .bias_addition import *
from .node_util import MetaInfo
from .tracer.symbolic_trace import symbolic_trace
@@ -1,4 +1,7 @@
import linecache
import os
import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union
@@ -6,11 +9,74 @@ from typing import Any, Dict, Optional, Union
import torch
import torch.fx
import torch.nn as nn
from torch.fx.graph import PythonCode

try:
    from torch.fx.graph import _PyTreeCodeGen
    SUPPORT_PT_CODEGEN = True
except ImportError:
    SUPPORT_PT_CODEGEN = False

from torch.fx.graph_module import _exec_with_source, _forward_from_src
from torch.nn.modules.module import _addindent


# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:

    def __init__(self, cls, cls_call):
        self.cls = cls
        self.cls_call = cls_call

    # Previously, if an error occurred when valid
    # symbolically-traced code was run with an invalid input, the
    # user would see the source of the error as coming from
    # `File "<eval_with_key_N">`, where N is some number. We use
    # this function to generate a more informative error message. We
    # return the traceback itself, a message explaining that the
    # error occurred in a traced Module's generated forward
    # function, and five lines of context surrounding the faulty
    # line
    @staticmethod
    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
        # auxiliary variables (for readability)
        err_lineno = frame_summary.lineno
        assert err_lineno is not None
        line = frame_summary.line
        assert line is not None
        err_line_len = len(line)
        all_src_lines = linecache.getlines(frame_summary.filename)

        # constituent substrings of the error message
        tb_repr = traceback.format_exc()
        custom_msg = ("Call using an FX-traced Module, "
                      f"line {err_lineno} of the traced Module's "
                      "generated forward function:")
        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
        marker = "~" * err_line_len + "~~~ <--- HERE"
        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])

        # joined message
        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])

    def __call__(self, obj, *args, **kwargs):
        try:
            if self.cls_call is not None:
                return self.cls_call(obj, *args, **kwargs)
            else:
                return super(self.cls, obj).__call__(*args, **kwargs)    # type: ignore[misc]
        except Exception as e:
            assert e.__traceback__
            topmost_framesummary: traceback.FrameSummary = \
                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]    # type: ignore[arg-type]
            if "eval_with_key" in topmost_framesummary.filename:
                print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
                raise e.with_traceback(None)
            else:
                raise e


class ColoGraphModule(torch.fx.GraphModule):
    """
    ColoGraphModule is an nn.Module generated from an fx.Graph.
@@ -65,7 +131,7 @@ class ColoGraphModule(torch.fx.GraphModule):
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module='self')
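A usage sketch of the guarded `recompile()` path (assumes only public `torch.fx.GraphModule` behaviour, which `ColoGraphModule` inherits):

import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = torch.fx.symbolic_trace(M())
# ... edit gm.graph here ...
gm.recompile()    # regenerates forward(); on torch >= 1.12 the _PyTreeCodeGen branch above also runs
assert gm(torch.ones(2)).sum() == 4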
@@ -20,7 +20,7 @@ def union(a, b):
    return {**a, **b}


def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Compute the size of a tensor or a collection of tensors in bytes.

    Args:
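The `X | Y` annotation syntax fails at import time on Python < 3.10, hence the switch to `typing.Union`. For reference, what the function is measuring (an illustrative check using standard dtype sizes):

import torch

t = torch.empty(2, 3, dtype=torch.float32)
# 6 elements x 4 bytes each = 24 bytes of element data
assert t.numel() * t.element_size() == 24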
@@ -195,8 +195,8 @@ class MetaInfo:
            s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
        if self.output_size:
            s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
        # if self.total_size:
        #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
        if self.temp_size:
            s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
        if self.backward_size:
@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
            with self.global_hook:
                r = getattr(self, n.op)(n.target, args, kwargs)

        def unwrap_fn(elem):

            def _convert_meta(t: torch.Tensor):
                if t.device == 'meta':
                    return t
                else:
                    return t.to('meta')

            if isinstance(elem, MetaTensor):
                return _convert_meta(elem._tensor)

            elif isinstance(elem, torch.Tensor):
                return _convert_meta(elem)

            else:
                return elem

        # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
        is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
        n_info = MetaInfo(n)
        n_info.outputs = _normalize_tuple(r)
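The rewritten `unwrap_fn` guarantees that whatever leaves the interpreter sits on the meta device, whether or not it was wrapped; a quick illustration of the `_convert_meta` behaviour:

import torch

t = torch.randn(4, 4)
m = t.to('meta')    # shape and dtype survive, storage is dropped
assert m.device.type == 'meta' and m.shape == t.shape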
@@ -0,0 +1,2 @@
from .bias_addition import *
from .custom_leaf_module import *
@@ -4,11 +4,10 @@ graph construction to deal with the compatibility between bias-addition and all-
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _single, _triple

from .tracer import register_tracer_impl

__all__ = []
@@ -0,0 +1,29 @@
import torch

from .tracer import register_leaf_module, register_leaf_module_impl

try:
    import apex
    register_leaf_module(apex.normalization.FusedLayerNorm)
    register_leaf_module(apex.normalization.FusedRMSNorm)
    register_leaf_module(apex.normalization.MixedFusedLayerNorm)
    register_leaf_module(apex.normalization.MixedFusedRMSNorm)

    @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
    @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
    @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
    @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
    def torch_nn_normalize(self, input: torch.Tensor):
        # check shape
        if isinstance(self, torch.nn.BatchNorm1d):
            assert input.dim() in [2, 3]
        elif isinstance(self, torch.nn.BatchNorm2d):
            assert input.dim() == 4
        elif isinstance(self, torch.nn.BatchNorm3d):
            assert input.dim() == 5

        # normalization maintains the same shape as the input
        return input.clone()

except (ImportError, AttributeError):
    pass
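The same hooks should work for user modules, not just apex (a hedged sketch; `MyFusedNorm` is a hypothetical stand-in):

import torch
import torch.nn as nn

class MyFusedNorm(nn.Module):    # stand-in for any opaque fused kernel
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

register_leaf_module(MyFusedNorm)    # trace it as a single call_module node

@register_leaf_module_impl(MyFusedNorm)
def my_fused_norm_impl(self, input: torch.Tensor):
    # shape-preserving, mirroring the normalization rule above
    return input.clone()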
@@ -0,0 +1,112 @@
import operator
from typing import Any, Callable, Dict, Optional, Set, Union

import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor

Target = Union[Callable[..., Any], str]


class ColoProxy(Proxy):
    _func_dispatch: Dict[Target, Callable[..., Any]] = {}

    def __init__(self, *args, data=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._meta_data = data

    @property
    def meta_data(self):
        return self._meta_data

    @meta_data.setter
    def meta_data(self, args):
        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
        self._meta_data = tree_map(wrap_fn, args)

    @classmethod
    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if orig_method in cls._func_dispatch:
            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
            proxy = impl(*args, **kwargs)
            cls._func_dispatch[orig_method] = impl
            return proxy
        else:
            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
            if proxy.meta_data is None:
                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
            return proxy

    @classmethod
    def from_torch_proxy(cls, proxy: Proxy):
        return cls(proxy.node, proxy.tracer)

    def __repr__(self):
        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"

    def __len__(self):
        return len(self.meta_data)

    def __int__(self):
        return int(self.meta_data)

    def __index__(self):
        try:
            return int(self.meta_data)
        except:
            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()

    def __float__(self):
        return float(self.meta_data)

    def __bool__(self):
        return self.meta_data

    def __getattr__(self, k):
        return ColoAttribute(self, k, getattr(self._meta_data, k, None))

    def __setitem__(self, key, value):
        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
        proxy.meta_data = self._meta_data
        return proxy

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle cases like `if x in kwargs`;
            # we don't handle this case for now
            return False
        return super().__contains__(key)

    def __isinstancecheck__(self, type):
        return isinstance(self.meta_data, type)


class ColoAttribute(ColoProxy):

    def __init__(self, root, attr: str, data=None):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._meta_data = data
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

    def __repr__(self):
        return f"ColoAttribute({self.node.name}, attr={self.attr})"

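As context for the ``meta_data`` plumbing above: every proxy carries a ``MetaTensor``, so tensor shapes propagate during tracing without real computation. A rough standalone sketch, assuming ``MetaTensor`` wraps a plain tensor exactly as the ``wrap_fn`` above does:

import torch

from colossalai._analyzer._subclasses import MetaTensor

x = MetaTensor(torch.empty(4, 8))     # wrapped activation, shape (4, 8)
w = MetaTensor(torch.empty(16, 8))    # wrapped weight, shape (16, 8)

# dispatches through MetaTensor: shape inference only, no real compute
y = torch.nn.functional.linear(x, w)
print(y.shape)    # torch.Size([4, 16])
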
@@ -0,0 +1,157 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch
from torch.fx import Tracer
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor

try:
    from ..codegen import ActivationCheckpointCodeGen
    SUPPORT_ACTIVATION = True
except:
    SUPPORT_ACTIVATION = False
from ..graph_module import ColoGraphModule
from .tracer import ColoTracer


def _default_device():
    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def _current_device(module: torch.nn.Module):
    try:
        return next(module.parameters()).device
    except:
        return _default_device()


def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
    trace_act_ckpt: bool = False,
    bias_addition_split: bool = False,
) -> ColoGraphModule:
    """
    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
    attached to the ``Node``s.

    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of modules
    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).

    This tracer is able to trace basic control flow and for loops.

    It will split the bias addition into two parts if ``bias_addition_split`` is set to ``True``.
    (See ./bias_addition.py for more details.)

    Examples:
    1. Tracing a ``torch.nn.Module`` with control flow.

    .. code-block:: python

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                if x.size(0) > 1:
                    x = x.sum(dim=0)
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     linear_1 = self.linear(x)
        #     return linear_1

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     sum = x.sum(dim=0); x = None
        #     linear = self.linear(sum); sum = None
        #     return linear

    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.

    .. code-block:: python

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):

                def custom_forward(x):
                    return self.linear(x)

                return torch.utils.checkpoint.checkpoint(custom_forward, x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)

        # traced code like:
        # def checkpoint_0(self, x):
        #     linear = self.linear(x); x = None
        #     return linear
        #
        # def forward(self, x):
        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
        #     return linear

    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.

    .. code-block:: python

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2, bias=True)

            def forward(self, x):
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)

        # traced code like:
        # def forward(self, x):
        #     linear_bias = self.linear.bias
        #     linear_weight = self.linear.weight
        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
        #     add = linear + linear_bias; linear = linear_bias = None
        #     return add

    Args:
        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
            Defaults to None.
        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
            for tracing control flow. Defaults to None.
        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
            Defaults to False.
        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.

    Returns:
        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.

    Remarks:
        This part of ``symbolic_trace()`` is maintained by the Colossal-AI team. If you encounter
        any unexpected errors during tracing, feel free to raise an issue on the Colossal-AI GitHub
        repo. We welcome any feedback and contributions to enhance the extensibility of
        Colossal-AI.
    """
    if meta_args:
        device, orig_device = _default_device(), _current_device(root)
        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
                           bias_addition_split=bias_addition_split).trace(root.to(device),
                                                                          concrete_args=concrete_args,
                                                                          meta_args=tree_map(wrap_fn, meta_args))
        if trace_act_ckpt and SUPPORT_ACTIVATION:
            graph.set_codegen(ActivationCheckpointCodeGen())
        root.to(orig_device)
    else:
        graph = Tracer().trace(root, concrete_args=concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)

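A hedged usage sketch, with a hypothetical module and flag, of combining ``meta_args`` with ``concrete_args`` to specialize a non-tensor argument at trace time:

import torch

from colossalai._analyzer.fx import symbolic_trace


class GatedModule(torch.nn.Module):

    def forward(self, x, use_sum: bool = False):
        return x.sum(dim=0) if use_sum else x.mean(dim=0)


# `use_sum` is baked in as a constant at trace time; only `x` stays symbolic
traced = symbolic_trace(GatedModule(),
                        meta_args={'x': torch.randn(2, 2)},
                        concrete_args={'use_sum': True})
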
@@ -1,28 +1,19 @@
import functools
import inspect
import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor, _TensorPropertyMethod, _TorchFactoryMethod
from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod

from .codegen import ActivationCheckpointCodeGen
from .graph_module import ColoGraphModule
from .node_util import MetaInfo
from ..node_util import MetaInfo
from .proxy import ColoProxy

Target = Union[Callable[..., Any], str]
Argument = Optional[Union[Tuple[Any, ...],    # actually Argument, but mypy can't represent recursive types
                          List[Any],    # actually Argument
                          Dict[str, Any],    # actually Argument
                          slice,    # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
                          'Node',]]
zeros = torch.zeros


def _truncate_suffix(s: str):

@@ -32,17 +23,6 @@ def _truncate_suffix(s: str):
    return re.sub(r'_\d+$', '', s)


def _default_device():
    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def _current_device(module):
    try:
        return next(module.parameters()).device
    except:
        return _default_device()


def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):

    def wrapper(impl):

@@ -70,149 +50,6 @@ def register_non_leaf_module(module: nn.Module):
    ColoTracer._custom_non_leaf_module.add(module)


class ColoProxy(Proxy):
    _func_dispatch: Dict[Target, Callable[..., Any]] = {}

    def __init__(self, *args, data=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._meta_data = data

    @property
    def meta_data(self):
        return self._meta_data

    @meta_data.setter
    def meta_data(self, args):
        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
        self._meta_data = tree_map(wrap_fn, args)

    @classmethod
    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if orig_method in cls._func_dispatch:
            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
            proxy = impl(*args, **kwargs)
            cls._func_dispatch[orig_method] = impl
            return proxy
        else:
            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
            if proxy.meta_data is None:
                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
            return proxy

    @classmethod
    def from_torch_proxy(cls, proxy: Proxy):
        return cls(proxy.node, proxy.tracer)

    def __repr__(self):
        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"

    def __len__(self):
        return len(self.meta_data)

    def __int__(self):
        return int(self.meta_data)

    def __index__(self):
        try:
            return int(self.meta_data)
        except:
            return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()

    def __float__(self):
        return float(self.meta_data)

    def __bool__(self):
        return self.meta_data

    def __getattr__(self, k):
        return ColoAttribute(self, k, getattr(self._meta_data, k, None))

    def __setitem__(self, key, value):
        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
        proxy.meta_data = self._meta_data
        return proxy

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle cases like `if x in kwargs`;
            # we don't handle this case for now
            return False
        return super().__contains__(key)

    def __isinstancecheck__(self, type):
        return isinstance(self.meta_data, type)

    def size(self, dim=None):
        if self._meta_data is not None:
            return self._meta_data.size(*[dim] if dim else [])
        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})

    def dim(self):
        if self._meta_data is not None:
            return self._meta_data.dim()
        return self.tracer.create_proxy('call_method', 'dim', (self,), {})

    @property
    def shape(self):
        if self._meta_data is not None:
            return self._meta_data.shape
        return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})

    @property
    def ndim(self):
        if self._meta_data is not None:
            return self._meta_data.ndim
        return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})

    @property
    def device(self):
        if self._meta_data is not None:
            return self._meta_data.device
        return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})

    @property
    def dtype(self):
        if self._meta_data is not None:
            return self._meta_data.dtype
        return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})

    def to(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})

    def cpu(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})

    def cuda(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})


class ColoAttribute(ColoProxy):

    def __init__(self, root, attr: str, data=None):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._meta_data = data
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)

    def __repr__(self):
        return f"ColoAttribute({self.node.name}, attr={self.attr})"


class ColoTracer(Tracer):
    _custom_leaf_module: Set[Type[nn.Module]] = set()
    _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}

|
|||
# we will enter the module and split the bias-addition ops
|
||||
if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
|
||||
return False
|
||||
|
||||
# user can specify which modules are leaf modules and which are not
|
||||
return (type(m) not in self._custom_non_leaf_module
|
||||
and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
|
||||
|
@@ -306,9 +142,13 @@ class ColoTracer(Tracer):
        mod = self.root.get_submodule(target)
        self.disable_module_getattr = True
        try:
            proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
                                                                mod.forward)(*tree_map(unwrap_fn, args),
                                                                             **tree_map(unwrap_fn, kwargs))
            args = tree_map(unwrap_fn, args)
            kwargs = tree_map(unwrap_fn, kwargs)
            if type(mod) in self._custom_leaf_module:
                target = self._custom_leaf_module_impl[type(mod)]
                proxy.meta_data = target(mod, *args, **kwargs)
            else:
                proxy.meta_data = mod.forward(*args, **kwargs)
        finally:
            self.disable_module_getattr = False
        return proxy

@@ -320,15 +160,21 @@ class ColoTracer(Tracer):

    def trace(self,
              root: torch.nn.Module,
              concrete_args: Optional[Dict[str, torch.Tensor]] = {},
              meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:

        if meta_args is None:
            meta_args = {}

        if concrete_args is None:
            concrete_args = {}

        # check concrete and meta args have valid names
        sig = inspect.signature(root.forward)
        sig_names = set(sig.parameters.keys())
        meta_arg_names = set(meta_args.keys())
        concrete_arg_names = set(concrete_args.keys())

        non_concrete_arg_names = sig_names - concrete_arg_names
        # update concrete args with default values
        for k, v in sig.parameters.items():
            if k in sig_names - meta_arg_names and \

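Note on the signature change above: a ``{}`` default is created once at function definition and then shared across every call, so the ``None``-plus-guard pattern is the safe replacement. A standalone illustration:

def bad(d={}):    # one dict, created at definition time, shared by every call
    d['n'] = d.get('n', 0) + 1
    return d


def good(d=None):    # a fresh dict per call
    if d is None:
        d = {}
    d['n'] = d.get('n', 0) + 1
    return d


assert bad() is bad()          # same object: state leaks between calls
assert good() is not good()    # independent objects
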
@@ -352,6 +198,34 @@ class ColoTracer(Tracer):
        self.graph = super().trace(root, concrete_args=concrete_args)
        self.mod_dir = ''
        self.graph.lint()

        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in non_concrete_arg_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                        # Newer versions of torch.fx emit an assert statement
                        # for concrete arguments; delete those before we delete
                        # the concrete arg.
                        to_delete = []
                        for user in node.users:
                            if user.target == torch.fx._symbolic_trace._assert_is_none:
                                to_delete.append(user)
                        for user in to_delete:
                            self.graph.erase_node(user)

                    self.graph.erase_node(node)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None
        return self.graph

    @contextmanager

|
|||
if node.op == "output":
|
||||
node.type = None
|
||||
self.graph.lint()
|
||||
|
||||
|
||||
def getattr(self, attr, attr_val, parameter_proxy_cache):
|
||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
|
||||
|
@@ -487,134 +361,3 @@ class ColoTracer(Tracer):
                return maybe_parameter_proxy

        return attr_val


def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = {},
    meta_args: Optional[Dict[str, Any]] = {},
    trace_act_ckpt: bool = False,
    bias_addition_split: bool = False,
) -> ColoGraphModule:
    """
    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
    attached to the ``Node``s.

    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of modules
    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).

    This tracer is able to trace basic control flow and for loops.

    It will split the bias addition into two parts if ``bias_addition_split`` is set to ``True``.
    (See ./bias_addition.py for more details.)

    Examples:
    1. Tracing a ``torch.nn.Module`` with control flow.

    .. code-block:: python

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                if x.size(0) > 1:
                    x = x.sum(dim=0)
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     linear_1 = self.linear(x)
        #     return linear_1

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     sum = x.sum(dim=0); x = None
        #     linear = self.linear(sum); sum = None
        #     return linear

    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.

    .. code-block:: python

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):

                def custom_forward(x):
                    return self.linear(x)

                return torch.utils.checkpoint.checkpoint(custom_forward, x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)

        # traced code like:
        # def checkpoint_0(self, x):
        #     linear = self.linear(x); x = None
        #     return linear
        #
        # def forward(self, x):
        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
        #     return linear

    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.

    .. code-block:: python

        class MyModule(torch.nn.Module):

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2, bias=True)

            def forward(self, x):
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)

        # traced code like:
        # def forward(self, x):
        #     linear_bias = self.linear.bias
        #     linear_weight = self.linear.weight
        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
        #     add = linear + linear_bias; linear = linear_bias = None
        #     return add

    Args:
        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
            Defaults to {}.
        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
            for tracing control flow. Defaults to {}.
        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
            Defaults to False.
        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.

    Returns:
        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.

    Remarks:
        This part of ``symbolic_trace()`` is maintained by the Colossal-AI team. If you encounter
        any unexpected errors during tracing, feel free to raise an issue on the Colossal-AI GitHub
        repo. We welcome any feedback and contributions to enhance the extensibility of
        Colossal-AI.
    """
    if meta_args:
        device, orig_device = _default_device(), _current_device(root)
        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
                           bias_addition_split=bias_addition_split).trace(root.to(device),
                                                                          concrete_args=concrete_args,
                                                                          meta_args=tree_map(wrap_fn, meta_args))
        if trace_act_ckpt:
            graph.set_codegen(ActivationCheckpointCodeGen())
        root.to(orig_device)
    else:
        graph = Tracer().trace(root, concrete_args=concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)

@@ -1,5 +1,4 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers

from .registry import model_zoo

__all__ = ['model_zoo']

@@ -17,6 +17,14 @@ def data_gen():
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


def seq_classification_data_gen():
    # batch size should be 1 if no padding token is defined
    input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
    token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
    attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


output_transform_fn = lambda x: x

config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)

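The batch-size-1 restriction comes from ``GPT2ForSequenceClassification``, which needs a padding token to locate the last non-padded position in batched input. A hedged alternative, not part of this commit, would be to define a padding token in the config so larger batches are accepted:

# hypothetical variant: set pad_token_id so batched inputs work
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4, pad_token_id=50256)
input_ids = torch.zeros((2, SEQ_LENGTH), dtype=torch.int64)
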
@@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
                   data_gen_fn=data_gen,
                   data_gen_fn=seq_classification_data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))

@@ -1,5 +1,6 @@
import pytest
import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

try:

@@ -73,7 +74,7 @@ class AddmmModel(torch.nn.Module):
        return x


@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])

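The switch to ``packaging.version`` above matters because comparing version strings lexicographically mis-orders releases once a component reaches two digits:

from packaging import version

print('1.2.0' < '1.12.0')                                  # False: '2' > '1' character-wise
print(version.parse('1.2.0') < version.parse('1.12.0'))    # True: compared numerically
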
@@ -3,7 +3,8 @@ from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten

from colossalai.fx import symbolic_trace
# from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace


def trace_model_and_compare_output(model, data_gen):

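For context, a rough sketch of what a helper like ``trace_model_and_compare_output`` typically does; this is a reconstruction for orientation, not the actual body in this commit:

import torch
from torch.utils._pytree import tree_flatten


def trace_model_and_compare_output(model, data_gen):
    # trace on meta args, then run both the original and the traced module
    # on real data and compare the flattened outputs
    data = data_gen()
    gm = symbolic_trace(model, meta_args=data)
    model.eval()
    with torch.no_grad():
        non_fx_out = model(**data)
        fx_out = gm(**data)
    flat_a, _ = tree_flatten(non_fx_out)
    flat_b, _ = tree_flatten(fx_out)
    for a, b in zip(flat_a, flat_b):
        if torch.is_tensor(a):
            assert torch.allclose(a, b), 'traced output mismatch'
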
@@ -1,4 +1,7 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from tests.kit.model_zoo import model_zoo


@@ -6,6 +9,7 @@ BATCH_SIZE = 2
SEQ_LENGTH = 16


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_albert():
    sub_registry = model_zoo.get_sub_registry('transformers_albert')

@@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from tests.kit.model_zoo import model_zoo


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_bert():
    sub_registry = model_zoo.get_sub_registry('transformers_bert')

@@ -1,16 +1,24 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from tests.kit.model_zoo import model_zoo


# TODO: remove this skip once we handle the latest gpt model
@pytest.mark.skip
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_gpt():
    sub_registry = model_zoo.get_sub_registry('transformers_gpt')

    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
        model = model_fn()

        # TODO: support the following models
        # 1. GPT2DoubleHeadsModel
        # as they are not supported, let's skip them
        if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
            continue

        trace_model_and_compare_output(model, data_gen_fn)

@@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from tests.kit.model_zoo import model_zoo


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_opt():
    sub_registry = model_zoo.get_sub_registry('transformers_opt')

@@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version

from tests.kit.model_zoo import model_zoo


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_t5():
    sub_registry = model_zoo.get_sub_registry('transformers_t5')

@@ -1,8 +1,8 @@
import pytest
import timm.models as tm
import torch
from packaging import version

from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo


@@ -42,6 +42,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
        f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_timm_models():
    torch.backends.cudnn.deterministic = True

@@ -1,20 +1,18 @@
import re

import pytest
import torch
from packaging import version
from torchaudio_utils import trace_and_compare

from tests.kit.model_zoo import model_zoo


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 1.12')
def test_torchaudio_models():
    torch.backends.cudnn.deterministic = True

    sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
        # FIXME(ver217): temporarily skip these models
        if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
            continue
        model = model_fn()
        trace_and_compare(model,
                          data_gen_fn,

@@ -1,6 +1,6 @@
import torch

from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace


def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):

@@ -1,7 +1,7 @@
import pytest
import torch

from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo

BATCH = 2

@@ -1,7 +1,7 @@
import pytest
import torch

from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo

BATCH = 2

@@ -1,6 +1,6 @@
import torch

from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo