mirror of https://github.com/hpcaitech/ColossalAI

[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop
* pass t5 trace and meta_prop
* [FX] refactor experimental tracer and adapt it with hf models
* pass all mainstream model zoo
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* skip tests
* fix CI
* using packaging version
* polish

pull/3197/head
parent b429529365
commit f57d34958b

@@ -6,11 +6,15 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+from packaging import version
 from torch.utils._pytree import tree_map
 
 aten = torch.ops.aten
 
-meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+try:
+    meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+except AttributeError:
+    meta_lib = None
 
 meta_table = {}

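For readers unfamiliar with the `torch.library` mechanism that `register_meta` builds on: a meta kernel computes only output metadata (shape, dtype, device), never values, which is what lets the tracer propagate shapes without allocating real memory. Below is a minimal, self-contained sketch of the same pattern, using a scratch `demo` namespace and a `double` operator that are both invented here purely for illustration (assuming a torch recent enough that `torch.library.Library` exists, as the guard above requires):

import torch

try:
    from torch.library import Library
except ImportError:    # mirrors the AttributeError guard above on old torch
    Library = None

if Library is not None:
    # define a toy operator in a scratch namespace
    lib = Library("demo", "DEF")
    lib.define("double(Tensor x) -> Tensor")

    # a real kernel for CPU ...
    Library("demo", "IMPL", "CPU").impl("double", lambda x: x * 2)
    # ... and a meta kernel that only fabricates the output's metadata
    Library("demo", "IMPL", "Meta").impl("double", lambda x: torch.empty_like(x))

    out = torch.ops.demo.double(torch.empty(4, device="meta"))
    print(out.shape, out.device)    # torch.Size([4]) meta
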
@@ -50,432 +54,411 @@ def register_meta(op, register_dispatcher=True):

This hunk gates every meta registration on the installed torch version: the whole block is wrapped in a `packaging.version` check and indented one level, while the registrations themselves are unchanged. The new version reads:

    return wrapper


if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # ============================== Convolutions ======================================
    # https://github.com/pytorch/pytorch/pull/79834
    @register_meta(aten.convolution.default)
    def meta_conv(
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        stride: List[int],
        padding: List[int],
        dilation: List[int],
        is_transposed: bool,
        output_padding: List[int],
        groups: int,
    ):

        def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
            """
            Formula to apply to calculate the length of some dimension of the output
            See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

            Args:
                ln: length of the dimension
                p: padding in that dim
                d: dilation in that dim
                k: kernel size in that dim
                s: stride in that dim
            Returns:
                The output length
            """
            return (ln + 2 * p - d * (k - 1) - 1) // s + 1

        def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
            """
            Formula to apply to calculate the length of some dimension of the output
            if transposed convolution is used.
            See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

            Args:
                ln: length of the dimension
                p: padding in that dim
                d: dilation in that dim
                k: kernel size in that dim
                s: stride in that dim
                op: output padding in that dim
            Returns:
                The output length
            """
            return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

        def calc_conv_nd_return_shape(
            dims: torch.Size,
            kernel_size: torch.Size,
            stride: Union[List[int], int],
            padding: Union[List[int], int],
            dilation: Union[List[int], int],
            output_padding: Optional[Union[List[int], int]] = None,
        ):
            ret_shape = []
            if isinstance(stride, int):
                stride = [stride] * len(dims)
            elif len(stride) == 1:
                stride = [stride[0]] * len(dims)

            if isinstance(padding, int):
                padding = [padding] * len(dims)
            elif len(padding) == 1:
                padding = [padding[0]] * len(dims)

            if isinstance(dilation, int):
                dilation = [dilation] * len(dims)
            elif len(dilation) == 1:
                dilation = [dilation[0]] * len(dims)

            output_padding_list: Optional[List[int]] = None
            if output_padding:
                if isinstance(output_padding, int):
                    output_padding_list = [output_padding] * len(dims)
                elif len(output_padding) == 1:
                    output_padding_list = [output_padding[0]] * len(dims)
                else:
                    output_padding_list = output_padding

            for i in range(len(dims)):
                # If output_padding is present, we are dealing with a transposed convolution
                if output_padding_list:
                    ret_shape.append(
                        _formula_transposed(
                            dims[i],
                            padding[i],
                            dilation[i],
                            kernel_size[i],
                            stride[i],
                            output_padding_list[i],
                        ))
                else:
                    ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
            return ret_shape

        def pick_memory_format():
            if input_tensor.is_contiguous(memory_format=torch.channels_last):
                return torch.channels_last
            elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
                return torch.contiguous_format
            elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
                return torch.preserve_format

        kernel_size = weight.shape[2:]
        dims = input_tensor.shape[2:]
        if is_transposed:
            out_channels = groups * weight.shape[1]

            shape_out = calc_conv_nd_return_shape(
                dims,
                kernel_size,
                stride,
                padding,
                dilation,
                output_padding,
            )

        else:
            out_channels = weight.shape[0]
            if weight.shape[1] != input_tensor.shape[1] / groups:
                raise RuntimeError("Invalid channel dimensions")
            shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
        out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
        mem_fmt = pick_memory_format()
        out = out.to(memory_format=mem_fmt)    # type: ignore[call-overload]
        return out

    @register_meta(aten._convolution.default)
    def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
                   padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int],
                   groups: int, *extra_args):
        out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
        return out

    @register_meta(aten.convolution_backward.default)
    def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
                           padding, dilation, transposed, output_padding, groups, output_mask):
        return new_like(input), new_like(weight), new((bias_sizes))

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
    @register_meta(aten._adaptive_avg_pool2d_backward.default)
    def meta_adaptive_avg_pool2d_backward(
        grad_output: torch.Tensor,
        input: torch.Tensor,
    ):
        return new_like(input)

    # ================================ RNN =============================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
    @register_meta(aten._cudnn_rnn.default)
    def meta_cuda_rnn(
        input,
        weight,
        weight_stride0,
        weight_buf,
        hx,
        cx,
        mode,
        hidden_size,
        proj_size,
        num_layers,
        batch_first,
        dropout,
        train,
        bidirectional,
        batch_sizes,
        dropout_state,
    ):
        is_input_packed = len(batch_sizes) != 0
        if is_input_packed:
            seq_length = len(batch_sizes)
            mini_batch = batch_sizes[0]
            batch_sizes_sum = input.shape[0]
        else:
            seq_length = input.shape[1] if batch_first else input.shape[0]
            mini_batch = input.shape[0] if batch_first else input.shape[1]
            batch_sizes_sum = -1

        num_directions = 2 if bidirectional else 1
        out_size = proj_size if proj_size != 0 else hidden_size
        if is_input_packed:
            out_shape = [batch_sizes_sum, out_size * num_directions]
        else:
            out_shape = ([mini_batch, seq_length, out_size *
                          num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
        output = input.new_empty(out_shape)

        cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
        cy = new(0) if cx is None else cx.new_empty(cell_shape)

        hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

        # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
        reserve_shape = 0 if train else 0
        reserve = input.new_empty(reserve_shape, dtype=torch.uint8)

        return output, hy, cy, reserve, weight_buf

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
    @register_meta(aten._cudnn_rnn_backward.default)
    def meta_cudnn_rnn_backward(input: torch.Tensor,
                                weight: torch.Tensor,
                                weight_stride0: int,
                                hx: torch.Tensor,
                                cx: Optional[torch.Tensor] = None,
                                *args,
                                **kwargs):
        return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
            ())    # (grad_input, grad_weight, grad_hx, grad_cx)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
    # ============================== Activations =======================================
    _unregistered_ewise = [
        aten.relu.default,
        aten.prelu.default,
        aten.hardswish.default,
        aten.hardtanh.default,
        aten.prelu_backward.default,
        aten.hardswish_backward.default,
        aten.hardtanh_backward.default,
    ]

    @register_meta(_unregistered_ewise)
    def meta_unregistered_ewise(input: torch.Tensor, *args):
        return new_like(input)

    # ============================== Normalization =====================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.native_batch_norm.default)
    def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
        n_input = input.size(1)
        return new_like(input), new((n_input)), new((n_input))

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.native_batch_norm_backward.default)
    def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                         save_mean, save_invstd, train, eps, output_mask):
        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    @register_meta(aten.cudnn_batch_norm.default)
    def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
        n_input = input.size(1)
        return new_like(input), new((n_input)), new((n_input)), new(
            (0), dtype=torch.uint8)    # (output, running_mean, running_var, reserve)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
    # NB: CuDNN only implements the backward algorithm for batchnorm
    # in training mode (evaluation mode batchnorm has a different algorithm),
    # which is why this doesn't accept a 'training' parameter.
    @register_meta(aten.cudnn_batch_norm_backward.default)
    def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
                               save_mean, save_invstd, eps, reserve):
        return new_like(input), new_like(weight), new_like(weight)    # (dX, dgamma, dbeta)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
    @register_meta(aten.native_layer_norm.default)
    def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
        bs, n_input = input.size(0), input.size(1)
        return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1))    # (output, running_mean, running_var)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
    @register_meta(aten.native_layer_norm_backward.default)
    def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
                         grad_input_mask):
        return new_like(input), new_like(weight), new_like(bias)    # (dX, dgamma, dbeta)

    # ================================== Misc ==========================================
    # Maybe incorrect
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
    @register_meta(aten.im2col.default)
    def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
        return new_like(input)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.eye.m_out)
    def meta_eye(n: int, m: int, out: torch.Tensor):
        return out

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.roll.default)
    def meta_roll(input: torch.Tensor, shifts, dims):
        return input

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
    @register_meta(aten._local_scalar_dense.default)
    def meta_local_scalar_dense(self: torch.Tensor):
        return 0

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
    @register_meta(aten.where.self)
    def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
        result_type = torch.result_type(self, other)
        return new_like(condition + self + other, dtype=result_type)

    @register_meta(aten.index.Tensor)
    def meta_index_Tensor(self, indices):
        assert indices, "at least one index must be provided"
        # aten::index is the internal advanced indexing implementation
        # checkIndexTensorTypes and expandTensors
        result: List[Optional[torch.Tensor]] = []
        for i, index in enumerate(indices):
            if index is not None:
                assert index.dtype in [torch.long, torch.int8, torch.bool],\
                    "tensors used as indices must be long, byte or bool tensors"
                if index.dtype in [torch.int8, torch.bool]:
                    nonzero = index.nonzero()
                    k = len(result)
                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                    for j in range(index.ndim):
                        assert index.shape[j] == self.shape[
                            k +
                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                        result.append(nonzero.select(1, j))
                else:
                    result.append(index)
            else:
                result.append(index)
        indices = result
        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
        # expand_outplace
        import torch._refs as refs

        indices = list(refs._maybe_broadcast(*indices))
        # add missing null tensors
        while len(indices) < self.ndim:
            indices.append(None)

        # hasContiguousSubspace
        # true if all non-null tensors are adjacent
        # See:
        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
        state = 0
        has_contiguous_subspace = False
        for index in indices:
            if state == 0:
                if index is not None:
                    state = 1
            elif state == 1:
                if index is None:
                    state = 2
            else:
                if index is not None:
                    break
        else:
            has_contiguous_subspace = True

        # transposeToFront
        # This is the logic that causes the newly inserted dimensions to show up
        # at the beginning of the tensor, if they're not contiguous
        if not has_contiguous_subspace:
            dims = []
            transposed_indices = []
            for i, index in enumerate(indices):
                if index is not None:
                    dims.append(i)
                    transposed_indices.append(index)
            for i, index in enumerate(indices):
                if index is None:
                    dims.append(i)
                    transposed_indices.append(index)
            self = self.permute(dims)
            indices = transposed_indices

        # AdvancedIndex::AdvancedIndex
        # Now we can assume the indices have contiguous subspace
        # This is simplified from AdvancedIndex which goes to more effort
        # to put the input and indices in a form so that TensorIterator can
        # take them. If we write a ref for this, probably that logic should
        # get implemented
        before_shape: List[int] = []
        after_shape: List[int] = []
        replacement_shape: List[int] = []
        for dim, index in enumerate(indices):
            if index is None:
                if replacement_shape:
                    after_shape.append(self.shape[dim])
                else:
                    before_shape.append(self.shape[dim])
            else:
                replacement_shape = list(index.shape)
        return self.new_empty(before_shape + replacement_shape + after_shape)

    # ============================== Embedding =========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
    @register_meta(aten.embedding_dense_backward.default)
    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                      scale_grad_by_freq):
        return new((num_weights, grad_output.size(-1)),
                   dtype=grad_output.dtype,
                   device=grad_output.device,
                   layout=grad_output.layout)

    # ============================== Dropout ===========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
    @register_meta(aten.native_dropout.default)
    def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
        # notice that mask is bool
        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
    @register_meta(aten.native_dropout_backward.default)
    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
        return new_like(grad)    # (grad_in)

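The `_formula` arithmetic above is the standard convolution output-size rule, out = floor((ln + 2p - d(k - 1) - 1) / s) + 1. As a quick sanity check, here it is worked for a well-known case (a ResNet-style 7x7 stem on a 224x224 input) and compared against a real kernel:

import torch

def _formula(ln, p, d, k, s):
    # same arithmetic as in meta_conv above
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1

# 224x224 input, kernel 7, stride 2, padding 3, dilation 1
print(_formula(224, 3, 1, 7, 2))    # (224 + 6 - 6 - 1) // 2 + 1 = 112

# cross-check against an actual convolution
conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
print(conv(torch.randn(1, 3, 224, 224)).shape)    # torch.Size([1, 64, 112, 112])
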
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed as dist
+from packaging import version
 
 aten = torch.ops.aten

@@ -49,40 +50,45 @@ _DistCommMethod = [

This hunk puts the ATen alias/in-place tables behind the same torch-version gate, falling back to empty lists on older torch; the list contents themselves are unchanged. The new version reads:

    "scatter",
]

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    # TODO: dive deep here
    # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    _InplaceATen = [
        aten.add_.Tensor,
        aten.add_.Scalar,
        aten.sub_.Tensor,
        aten.sub_.Scalar,
        aten.mul_.Tensor,
        aten.mul_.Scalar,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.pow_.Tensor,
        aten.pow_.Scalar,
    ]

    # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
    _MaybeInplaceATen = [
        aten.diagonal.default,
        aten.expand.default,
        aten.select.int,
        aten.slice.Tensor,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.permute.default,
        aten.unsqueeze.default,
        aten.as_strided.default,
    ]
else:
    _AliasATen = []
    _InplaceATen = []
    _MaybeInplaceATen = []

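The `_AliasATen` grouping exists because these ops return views of their input rather than fresh tensors, which changes how their memory cost must be counted. This is easy to confirm directly:

import torch

x = torch.randn(4, 4)
v = x.view(16)    # aten.view.default -- an alias op
t = x.t()         # aten.t.default   -- also an alias
print(v.data_ptr() == x.data_ptr())    # True: no new storage allocated
print(t.data_ptr() == x.data_ptr())    # True
y = x + 1         # a real compute op allocates fresh storage
print(y.data_ptr() == x.data_ptr())    # False
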
@@ -11,6 +11,7 @@ from numbers import Number
 from typing import Any, Callable, List, Optional, Union
 
 import torch
+from packaging import version
 from torch.utils._pytree import tree_map
 
 from .meta_tensor import MetaTensor

@ -403,134 +404,139 @@ def zero_flop_jit(*args):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
flop_mapping = {
|
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||||
|
flop_mapping = {
|
||||||
# gemm
|
# gemm
|
||||||
aten.mm.default: matmul_flop_jit,
|
aten.mm.default: matmul_flop_jit,
|
||||||
aten.matmul.default: matmul_flop_jit,
|
aten.matmul.default: matmul_flop_jit,
|
||||||
aten.addmm.default: addmm_flop_jit,
|
aten.addmm.default: addmm_flop_jit,
|
||||||
aten.bmm.default: bmm_flop_jit,
|
aten.bmm.default: bmm_flop_jit,
|
||||||
|
|
||||||
# convolution
|
# convolution
|
||||||
aten.convolution.default: conv_flop_jit,
|
aten.convolution.default: conv_flop_jit,
|
||||||
aten._convolution.default: conv_flop_jit,
|
aten._convolution.default: conv_flop_jit,
|
||||||
aten.convolution_backward.default: conv_backward_flop_jit,
|
aten.convolution_backward.default: conv_backward_flop_jit,
|
||||||
|
|
||||||
# normalization
|
# normalization
|
||||||
aten.native_batch_norm.default: batchnorm_flop_jit,
|
aten.native_batch_norm.default: batchnorm_flop_jit,
|
||||||
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
|
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
|
||||||
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
|
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
|
||||||
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
||||||
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
||||||
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
||||||
|
|
||||||
# pooling
|
# pooling
|
||||||
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
|
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
|
||||||
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
|
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||||
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten.avg_pool3d.default: ewise_flop_counter(1, 0),
|
aten.avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||||
aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten.max_pool1d.default: ewise_flop_counter(1, 0),
|
aten.max_pool1d.default: ewise_flop_counter(1, 0),
|
||||||
aten.max_pool2d.default: ewise_flop_counter(1, 0),
|
aten.max_pool2d.default: ewise_flop_counter(1, 0),
|
||||||
aten.max_pool3d.default: ewise_flop_counter(1, 0),
|
aten.max_pool3d.default: ewise_flop_counter(1, 0),
|
||||||
aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
|
aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
|
||||||
aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
|
aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
|
||||||
aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
|
aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
|
||||||
aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
|
aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||||
aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
|
aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||||
aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
|
aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
|
||||||
aten.embedding.default: ewise_flop_counter(1, 0),
|
aten.embedding.default: ewise_flop_counter(1, 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
ewise_flop_aten = [
|
ewise_flop_aten = [
|
||||||
# basic op
|
# basic op
|
||||||
aten.add.Tensor,
|
aten.add.Tensor,
|
||||||
aten.add_.Tensor,
|
aten.add_.Tensor,
|
||||||
aten.div.Tensor,
|
aten.div.Tensor,
|
||||||
aten.div_.Tensor,
|
aten.div_.Tensor,
|
||||||
aten.div.Scalar,
|
aten.div.Scalar,
|
||||||
aten.div_.Scalar,
|
aten.div_.Scalar,
|
||||||
aten.mul.Tensor,
|
aten.mul.Tensor,
|
||||||
aten.mul.Scalar,
|
aten.mul.Scalar,
|
||||||
aten.mul_.Tensor,
|
aten.mul_.Tensor,
|
||||||
aten.neg.default,
|
aten.neg.default,
|
||||||
aten.pow.Tensor_Scalar,
|
aten.pow.Tensor_Scalar,
|
||||||
aten.rsub.Scalar,
|
aten.rsub.Scalar,
|
||||||
aten.sum.default,
|
aten.sum.default,
|
||||||
aten.sum.dim_IntList,
|
aten.sum.dim_IntList,
|
||||||
aten.mean.dim,
|
aten.mean.dim,
|
||||||
|
|
||||||
# activation op
|
# activation op
|
||||||
aten.hardswish.default,
|
aten.hardswish.default,
|
||||||
aten.hardswish_.default,
|
aten.hardswish_.default,
|
||||||
aten.hardswish_backward.default,
|
aten.hardswish_backward.default,
|
||||||
aten.hardtanh.default,
|
aten.hardtanh.default,
|
||||||
aten.hardtanh_.default,
|
aten.hardtanh_.default,
|
||||||
aten.hardtanh_backward.default,
|
aten.hardtanh_backward.default,
|
||||||
aten.hardsigmoid_backward.default,
|
aten.hardsigmoid_backward.default,
|
||||||
aten.hardsigmoid.default,
|
aten.hardsigmoid.default,
|
||||||
aten.gelu.default,
|
aten.gelu.default,
|
||||||
aten.gelu_backward.default,
|
aten.gelu_backward.default,
|
||||||
aten.silu.default,
|
aten.silu.default,
|
||||||
aten.silu_.default,
|
aten.silu_.default,
|
||||||
aten.silu_backward.default,
|
aten.silu_backward.default,
|
||||||
aten.sigmoid.default,
|
aten.sigmoid.default,
|
||||||
aten.sigmoid_backward.default,
|
aten.sigmoid_backward.default,
|
||||||
aten._softmax.default,
|
aten._softmax.default,
|
||||||
aten._softmax_backward_data.default,
|
aten._softmax_backward_data.default,
|
||||||
aten.relu_.default,
|
aten.relu_.default,
|
||||||
aten.relu.default,
|
aten.relu.default,
|
||||||
aten.tanh.default,
|
aten.tanh.default,
|
||||||
aten.tanh_backward.default,
|
aten.tanh_backward.default,
|
||||||
aten.threshold_backward.default,
|
aten.threshold_backward.default,
|
||||||
|
|
||||||
# dropout
|
# dropout
|
||||||
aten.native_dropout.default,
|
aten.native_dropout.default,
|
||||||
aten.native_dropout_backward.default,
|
aten.native_dropout_backward.default,
|
||||||
|
|
||||||
# distribution
|
# distribution
|
||||||
aten.bernoulli_.float,
|
aten.bernoulli_.float,
|
||||||
|
|
||||||
# where
|
# where
|
||||||
aten.where.self,
|
aten.where.self,
|
||||||
]
|
]
|
||||||
for op in ewise_flop_aten:
|
for op in ewise_flop_aten:
|
||||||
flop_mapping[op] = ewise_flop_counter(1, 0)
|
flop_mapping[op] = ewise_flop_counter(1, 0)
|
||||||
|
|
||||||
# fix-me: this will be removed in future
|
# fix-me: this will be removed in future
|
||||||
zero_flop_aten = [
|
zero_flop_aten = [
|
||||||
aten.as_strided.default,
|
aten.as_strided.default,
|
||||||
aten.as_strided_.default,
|
aten.as_strided_.default,
|
||||||
aten.cat.default,
|
aten.cat.default,
|
||||||
aten.clone.default,
|
aten.clone.default,
|
||||||
aten.copy_.default,
|
aten.copy_.default,
|
||||||
aten.detach.default,
|
aten.detach.default,
|
||||||
aten.expand.default,
|
aten.expand.default,
|
||||||
aten.empty_like.default,
|
aten.empty_like.default,
|
||||||
aten.new_empty.default,
|
aten.new_empty.default,
|
||||||
aten.new_empty_strided.default,
|
aten.new_empty_strided.default,
|
||||||
aten.ones_like.default,
|
aten.ones_like.default,
|
||||||
aten._reshape_alias.default,
|
aten._reshape_alias.default,
|
||||||
aten.select.int,
|
aten.select.int,
|
||||||
aten.select_backward.default,
|
aten.select_backward.default,
|
||||||
aten.squeeze.dim,
|
aten.squeeze.dim,
|
||||||
aten.slice.Tensor,
|
aten.slice.Tensor,
|
||||||
aten.slice_backward.default,
|
aten.slice_backward.default,
|
||||||
aten.split.Tensor,
|
aten.split.Tensor,
|
||||||
aten.permute.default,
|
aten.permute.default,
|
||||||
aten.t.default,
|
aten.t.default,
|
||||||
aten.transpose.int,
|
aten.transpose.int,
|
||||||
aten._to_copy.default,
|
aten._to_copy.default,
|
||||||
aten.unsqueeze.default,
|
aten.unsqueeze.default,
|
||||||
aten.unbind.int,
|
aten.unbind.int,
|
||||||
aten._unsafe_view.default,
|
aten._unsafe_view.default,
|
||||||
aten.view.default,
|
aten.view.default,
|
||||||
aten.zero_.default,
|
aten.zero_.default,
|
||||||
aten.zeros_like.default,
|
aten.zeros_like.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
for op in zero_flop_aten:
|
for op in zero_flop_aten:
|
||||||
flop_mapping[op] = zero_flop_jit
|
flop_mapping[op] = zero_flop_jit
|
||||||
|
else:
|
||||||
|
flop_mapping = {}
|
||||||
|
elementwise_flop_aten = {}
|
||||||
|
zero_flop_aten = {}
|
||||||
|
|
|
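Note that `flop_mapping` is keyed by exact aten overload objects (`aten.mm.default`, not `torch.mm`), so a profiler hook looks up the counter for the overload actually dispatched. As rough orientation (the precise convention lives in `matmul_flop_jit`, whose body is not part of this hunk), an (m, k) @ (k, n) matmul costs on the order of m * k * n multiply-accumulates:

import torch

aten = torch.ops.aten
print(aten.mm.default)    # the overload object used as the dict key

# hypothetical lookup, assuming flop_mapping was built (torch >= 1.12.0):
# counter = flop_mapping[aten.mm.default]

m, k, n = 8, 16, 4
print(m * k * n)    # 512 multiply-accumulates for an (8,16) @ (16,4) matmul
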
@@ -1,4 +1,3 @@
-from .bias_addition import *
 from .node_util import MetaInfo
 from .symbolic_profile import symbolic_profile
-from .symbolic_trace import symbolic_trace
+from .tracer.symbolic_trace import symbolic_trace

@@ -1,4 +1,7 @@
+import linecache
 import os
+import sys
+import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Dict, Optional, Union

@@ -6,11 +9,74 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.fx
 import torch.nn as nn
-from torch.fx.graph import PythonCode, _PyTreeCodeGen
-from torch.fx.graph_module import _exec_with_source, _forward_from_src, _WrappedCall
+from torch.fx.graph import PythonCode
+
+try:
+    from torch.fx.graph import _PyTreeCodeGen
+    SUPPORT_PT_CODEGEN = True
+except ImportError:
+    SUPPORT_PT_CODEGEN = False
+
+from torch.fx.graph_module import _exec_with_source, _forward_from_src
 from torch.nn.modules.module import _addindent
 
 
+# This is a copy of torch.fx.graph_module._WrappedCall.
+# It should be removed when we stop supporting torch < 1.12.0.
+class _WrappedCall:
+
+    def __init__(self, cls, cls_call):
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, if an error occurred when valid
+    # symbolically-traced code was run with an invalid input, the
+    # user would see the source of the error as coming from
+    # `File "<eval_with_key_N">`, where N is some number. We use
+    # this function to generate a more informative error message. We
+    # return the traceback itself, a message explaining that the
+    # error occurred in a traced Module's generated forward
+    # function, and five lines of context surrounding the faulty
+    # line
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        # auxiliary variables (for readability)
+        err_lineno = frame_summary.lineno
+        assert err_lineno is not None
+        line = frame_summary.line
+        assert line is not None
+        err_line_len = len(line)
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        # constituent substrings of the error message
+        tb_repr = traceback.format_exc()
+        custom_msg = ("Call using an FX-traced Module, "
+                      f"line {err_lineno} of the traced Module's "
+                      "generated forward function:")
+        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+
+        # joined message
+        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            else:
+                return super(self.cls, obj).__call__(*args, **kwargs)    # type: ignore[misc]
+        except Exception as e:
+            assert e.__traceback__
+            topmost_framesummary: traceback.FrameSummary = \
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]    # type: ignore[arg-type]
+            if "eval_with_key" in topmost_framesummary.filename:
+                print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
+                raise e.with_traceback(None)
+            else:
+                raise e
+
+
 class ColoGraphModule(torch.fx.GraphModule):
     """
     ColoGraphModule is an nn.Module generated from an fx.Graph.

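The interesting part of `_WrappedCall.__call__` is how it locates the offending line: it walks the exception's traceback and keeps the innermost frame summary. The same two-liner works outside FX:

import traceback

def boom():
    raise ValueError("bad input")

try:
    boom()
except Exception as e:
    # same extraction as _WrappedCall.__call__ above: walk the traceback
    # and take the innermost frame
    fs = traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]
    print(fs.filename, fs.lineno, fs.line)    # points at the `raise` line
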
@@ -65,7 +131,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         called after editing the contained ``graph``, otherwise the generated
         code of this ``GraphModule`` will be out of date.
         """
-        if isinstance(self._graph._codegen, _PyTreeCodeGen):
+        if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
             self._in_spec = self._graph._codegen.pytree_info.in_spec
             self._out_spec = self._graph._codegen.pytree_info.out_spec
         python_code = self._graph.python_code(root_module='self')

@@ -20,7 +20,7 @@ def union(a, b):
     return {**a, **b}
 
 
-def compute_size_in_bytes(elem: torch.Tensor | Dict | List | Tuple | int) -> int:
+def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
     """Compute the size of a tensor or a collection of tensors in bytes.
 
     Args:

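The `X | Y` union syntax in the old annotation is only valid at runtime on Python 3.10+, hence the switch to `typing.Union`. The body of `compute_size_in_bytes` is not part of this hunk; the following is only a sketch of what such a recursive byte-size helper can look like, not the actual implementation:

import torch
from typing import Dict, List, Tuple, Union

def _size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    # tensors report bytes directly; containers are summed recursively
    if isinstance(elem, torch.Tensor):
        return elem.numel() * elem.element_size()
    if isinstance(elem, dict):
        return sum(_size_in_bytes(v) for v in elem.values())
    if isinstance(elem, (list, tuple)):
        return sum(_size_in_bytes(v) for v in elem)
    return 0
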
@@ -195,8 +195,8 @@ class MetaInfo:
             s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
         if self.output_size:
             s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
-        if self.total_size:
-            s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
+        # if self.total_size:
+        #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
         if self.temp_size:
             s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
         if self.backward_size:

@@ -111,7 +111,24 @@ class ShapeProp(torch.fx.Interpreter):
             with self.global_hook:
                 r = getattr(self, n.op)(n.target, args, kwargs)

-        unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
+        def unwrap_fn(elem):
+
+            def _convert_meta(t: torch.Tensor):
+                if t.device == 'meta':
+                    return t
+                else:
+                    return t.to('meta')
+
+            if isinstance(elem, MetaTensor):
+                return _convert_meta(elem._tensor)
+
+            elif isinstance(elem, torch.Tensor):
+                return _convert_meta(elem)
+
+            else:
+                return elem
+
+        # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
         is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
         n_info = MetaInfo(n)
         n_info.outputs = _normalize_tuple(r)
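The expanded ``unwrap_fn`` normalizes every tensor it sees onto the meta device, so shape propagation never touches real storage or launches kernels. A small sketch of the conversion it relies on, as observed on recent torch builds:

.. code-block:: python

    import torch

    x = torch.randn(8, 128)              # real tensor backed by storage
    m = x.to('meta')                     # shape and dtype preserved, storage dropped
    print(m.shape, m.dtype, m.device)    # torch.Size([8, 128]) torch.float32 meta
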
@@ -0,0 +1,2 @@
+from .bias_addition import *
+from .custom_leaf_module import *
@@ -4,11 +4,10 @@ graph construction to deal with the compatibility between bias-addition and all-
 """

 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.utils import _pair, _single, _triple

-from .symbolic_trace import register_tracer_impl
+from .tracer import register_tracer_impl

 __all__ = []

@@ -0,0 +1,29 @@
+import torch
+
+from .tracer import register_leaf_module, register_leaf_module_impl
+
+try:
+    import apex
+    register_leaf_module(apex.normalization.FusedLayerNorm)
+    register_leaf_module(apex.normalization.FusedRMSNorm)
+    register_leaf_module(apex.normalization.MixedFusedLayerNorm)
+    register_leaf_module(apex.normalization.MixedFusedRMSNorm)
+
+    @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
+    @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
+    @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
+    @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
+    def torch_nn_normalize(self, input: torch.Tensor):
+        # check shape
+        if isinstance(self, torch.nn.BatchNorm1d):
+            assert input.dim() in [2, 3]
+        elif isinstance(self, torch.nn.BatchNorm2d):
+            assert input.dim() == 4
+        elif isinstance(self, torch.nn.BatchNorm3d):
+            assert input.dim() == 5
+
+        # normalization maintains the same shape as the input
+        return input.clone()
+
+except (ImportError, AttributeError):
+    pass
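The same two-step registration can mark any third-party module as a leaf and supply a cheap shape-only stand-in for its forward. A hedged sketch for a user-defined module (``MyFusedNorm`` and the import path are illustrative assumptions, not part of this commit):

.. code-block:: python

    import torch

    # assumed import path; adjust to wherever register_leaf_module is exposed
    from colossalai._analyzer.fx.tracer import register_leaf_module, register_leaf_module_impl


    class MyFusedNorm(torch.nn.Module):    # hypothetical custom kernel wrapper

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x / (x.norm() + 1e-6)


    register_leaf_module(MyFusedNorm)    # don't trace into its forward


    @register_leaf_module_impl(MyFusedNorm)
    def my_fused_norm_shape(self, input: torch.Tensor):
        # like the apex handlers above: normalization keeps the input shape
        return input.clone()
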
@@ -0,0 +1,112 @@
+import operator
+from typing import Any, Callable, Dict, Optional, Set, Union
+
+import torch
+import torch.nn as nn
+from torch.fx import Graph, Node, Proxy, Tracer
+from torch.fx.graph import _Namespace
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor
+
+Target = Union[Callable[..., Any], str]
+
+
+class ColoProxy(Proxy):
+    _func_dispatch: Dict[Target, Callable[..., Any]] = {}
+
+    def __init__(self, *args, data=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._meta_data = data
+
+    @property
+    def meta_data(self):
+        return self._meta_data
+
+    @meta_data.setter
+    def meta_data(self, args):
+        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
+        self._meta_data = tree_map(wrap_fn, args)
+
+    @classmethod
+    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+        if orig_method in cls._func_dispatch:
+            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
+            proxy = impl(*args, **kwargs)
+            cls._func_dispatch[orig_method] = impl
+            return proxy
+        else:
+            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
+            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
+            if proxy.meta_data is None:
+                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+            return proxy
+
+    @classmethod
+    def from_torch_proxy(cls, proxy: Proxy):
+        return cls(proxy.node, proxy.tracer)
+
+    def __repr__(self):
+        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
+
+    def __len__(self):
+        return len(self.meta_data)
+
+    def __int__(self):
+        return int(self.meta_data)
+
+    def __index__(self):
+        try:
+            return int(self.meta_data)
+        except:
+            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
+
+    def __float__(self):
+        return float(self.meta_data)
+
+    def __bool__(self):
+        return self.meta_data
+
+    def __getattr__(self, k):
+        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
+
+    def __setitem__(self, key, value):
+        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+        proxy.meta_data = self._meta_data
+        return proxy
+
+    def __contains__(self, key):
+        if self.node.op == "placeholder":
+            # this handles cases like `if x in kwargs`;
+            # we don't handle this case for now
+            return False
+        return super().__contains__(key)
+
+    def __isinstancecheck__(self, type):
+        return isinstance(self.meta_data, type)
+
+
+class ColoAttribute(ColoProxy):
+
+    def __init__(self, root, attr: str, data=None):
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._meta_data = data
+        self._node: Optional[Node] = None
+
+    @property
+    def node(self):
+        # the node for attributes is added lazily, since most will just be method calls
+        # which do not rely on the getitem call
+        if self._node is None:
+            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+
+    def __repr__(self):
+        return f"ColoAttribute({self.node.name}, attr={self.attr})"
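What the ``__torch_function__`` override buys during tracing: every torch call on a ``ColoProxy`` records a node as usual, while the same op is replayed on the proxy's ``MetaTensor`` so shapes stay known without running real kernels. An illustrative snippet, assuming ``x_proxy`` is a ``ColoProxy`` produced by a ``ColoTracer`` trace:

.. code-block:: python

    # sketch only; x_proxy is assumed to come from an in-progress ColoTracer trace
    y = torch.relu(x_proxy)             # dispatches through ColoProxy.__torch_function__
    print(y.node.op, y.node.target)     # call_function, torch.relu -- recorded in the graph
    print(y.meta_data.shape)            # shape known from the MetaTensor replay, no kernel run
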
@@ -0,0 +1,157 @@
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from torch.fx import Tracer
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor
+
+try:
+    from ..codegen import ActivationCheckpointCodeGen
+    SUPPORT_ACTIVATION = True
+except:
+    SUPPORT_ACTIVATION = False
+from ..graph_module import ColoGraphModule
+from .tracer import ColoTracer
+
+
+def _default_device():
+    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+
+def _current_device(module: torch.nn.Module):
+    try:
+        return next(module.parameters()).device
+    except:
+        return _default_device()
+
+
+def symbolic_trace(
+    root: Union[torch.nn.Module, Callable[..., Any]],
+    concrete_args: Optional[Dict[str, Any]] = None,
+    meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt: bool = False,
+    bias_addition_split: bool = False,
+) -> ColoGraphModule:
+    """
+    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
+    attached to the ``Node``s.
+
+    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
+    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
+
+    This tracer is able to trace basic control flow and for loops.
+
+    It will split the bias addition into two parts if ``bias_addition_split`` is set to ``True``
+    (see ./bias_addition.py for more details).
+
+    Examples:
+    1. Tracing a ``torch.nn.Module`` with control flow.
+
+    .. code-block:: python
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                if x.size(0) > 1:
+                    x = x.sum(dim=0)
+                return self.linear(x)
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
+
+        # traced code like:
+        # def forward(self, x):
+        #     linear_1 = self.linear(x)
+        #     return linear_1
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
+
+        # traced code like:
+        # def forward(self, x):
+        #     sum = x.sum(dim=0); x = None
+        #     linear = self.linear(sum); sum = None
+        #     return linear
+
+    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
+
+    .. code-block:: python
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                def custom_forward(x):
+                    return self.linear(x)
+                return torch.utils.checkpoint.checkpoint(custom_forward, x)
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
+
+        # traced code like:
+        # def checkpoint_0(self, x):
+        #     linear = self.linear(x); x = None
+        #     return linear
+        #
+        # def forward(self, x):
+        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
+        #     return linear
+
+    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
+
+    .. code-block:: python
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2, bias=True)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
+
+        # traced code like:
+        # def forward(self, x):
+        #     linear_bias = self.linear.bias
+        #     linear_weight = self.linear.weight
+        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
+        #     add = linear + linear_bias; linear = linear_bias = None
+        #     return add
+
+    Args:
+        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
+        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
+            Defaults to None.
+        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
+            for tracing control flow. Defaults to None.
+        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
+            Defaults to False.
+        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
+
+    Returns:
+        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
+
+    Remarks:
+        This part of ``symbolic_trace()`` is maintained by the Colossal-AI team. If you encounter
+        any unexpected errors during tracing, feel free to raise an issue on the Colossal-AI GitHub
+        repo. We welcome any feedback and contributions to enhance the extensibility of
+        Colossal-AI.
+    """
+    if meta_args:
+        device, orig_device = _default_device(), _current_device(root)
+        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
+        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
+                           bias_addition_split=bias_addition_split).trace(root.to(device),
+                                                                          concrete_args=concrete_args,
+                                                                          meta_args=tree_map(wrap_fn, meta_args))
+        if trace_act_ckpt and SUPPORT_ACTIVATION:
+            graph.set_codegen(ActivationCheckpointCodeGen())
+        root.to(orig_device)
+    else:
+        graph = Tracer().trace(root, concrete_args=concrete_args)
+    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+    return ColoGraphModule(root, graph, name)
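Note the fallback in the body above: with no ``meta_args``, tracing goes through the vanilla ``torch.fx.Tracer`` and no meta data is attached. A minimal usage sketch, assuming a module without data-dependent control flow:

.. code-block:: python

    import torch

    from colossalai._analyzer.fx import symbolic_trace

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

    gm_meta = symbolic_trace(model, meta_args={'input': torch.randn(2, 4)})    # ColoTracer path
    gm_plain = symbolic_trace(model)                                           # plain torch.fx fallback
    print(gm_plain.code)
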
@ -1,28 +1,19 @@
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import operator
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.fx import Graph, Node, Proxy, Tracer
|
from torch.fx import Graph, Node, Proxy, Tracer
|
||||||
from torch.fx.graph import _Namespace
|
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
from colossalai._analyzer._subclasses import MetaTensor, _TensorPropertyMethod, _TorchFactoryMethod
|
from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
|
||||||
|
|
||||||
from .codegen import ActivationCheckpointCodeGen
|
from ..node_util import MetaInfo
|
||||||
from .graph_module import ColoGraphModule
|
from .proxy import ColoProxy
|
||||||
from .node_util import MetaInfo
|
|
||||||
|
|
||||||
Target = Union[Callable[..., Any], str]
|
Target = Union[Callable[..., Any], str]
|
||||||
Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
|
|
||||||
List[Any], # actually Argument
|
|
||||||
Dict[str, Any], # actually Argument
|
|
||||||
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
|
|
||||||
'Node',]]
|
|
||||||
zeros = torch.zeros
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate_suffix(s: str):
|
def _truncate_suffix(s: str):
|
||||||
|
@@ -32,17 +23,6 @@ def _truncate_suffix(s: str):
     return re.sub(r'_\d+$', '', s)


-def _default_device():
-    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-
-
-def _current_device(module):
-    try:
-        return next(module.parameters()).device
-    except:
-        return _default_device()
-
-
 def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):

     def wrapper(impl):

@@ -70,149 +50,6 @@ def register_non_leaf_module(module: nn.Module):
     ColoTracer._custom_non_leaf_module.add(module)


-class ColoProxy(Proxy):
-    _func_dispatch: Dict[Target, Callable[..., Any]] = {}
-
-    def __init__(self, *args, data=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._meta_data = data
-
-    @property
-    def meta_data(self):
-        return self._meta_data
-
-    @meta_data.setter
-    def meta_data(self, args):
-        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
-        self._meta_data = tree_map(wrap_fn, args)
-
-    @classmethod
-    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
-        kwargs = {} if kwargs is None else kwargs
-        if orig_method in cls._func_dispatch:
-            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
-            proxy = impl(*args, **kwargs)
-            cls._func_dispatch[orig_method] = impl
-            return proxy
-        else:
-            proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
-            unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
-            if proxy.meta_data is None:
-                proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
-            return proxy
-
-    @classmethod
-    def from_torch_proxy(cls, proxy: Proxy):
-        return cls(proxy.node, proxy.tracer)
-
-    def __repr__(self):
-        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
-
-    def __len__(self):
-        return len(self.meta_data)
-
-    def __int__(self):
-        return int(self.meta_data)
-
-    def __index__(self):
-        try:
-            return int(self.meta_data)
-        except:
-            return zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
-
-    def __float__(self):
-        return float(self.meta_data)
-
-    def __bool__(self):
-        return self.meta_data
-
-    def __getattr__(self, k):
-        return ColoAttribute(self, k, getattr(self._meta_data, k, None))
-
-    def __setitem__(self, key, value):
-        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
-        proxy.meta_data = self._meta_data
-        return proxy
-
-    def __contains__(self, key):
-        if self.node.op == "placeholder":
-            # this is used to handle like
-            # if x in kwargs
-            # we don't handle this case for now
-            return False
-        return super().__contains__(key)
-
-    def __isinstancecheck__(self, type):
-        return isinstance(self.meta_data, type)
-
-    def size(self, dim=None):
-        if self._meta_data is None:
-            return self._meta_data.size(*[dim] if dim else [])
-        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
-
-    def dim(self):
-        if self._meta_data is not None:
-            return self._meta_data.dim()
-        return self.tracer.create_proxy('call_method', 'dim', (self,), {})
-
-    @property
-    def shape(self):
-        if self._meta_data is not None:
-            return self._meta_data.shape
-        return self.tracer.create_proxy('call_function', getattr, (self, 'shape'), {})
-
-    @property
-    def ndim(self):
-        if self._meta_data is not None:
-            return self._meta_data.ndim
-        return self.tracer.create_proxy('call_function', getattr, (self, 'ndim'), {})
-
-    @property
-    def device(self):
-        if self._meta_data is not None:
-            return self._meta_data.device
-        return self.tracer.create_proxy('call_function', getattr, (self, 'device'), {})
-
-    @property
-    def dtype(self):
-        if self._meta_data is not None:
-            return self._meta_data.dtype
-        return self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {})
-
-    def to(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs})
-
-    def cpu(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs})
-
-    def cuda(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs})
-
-
-class ColoAttribute(ColoProxy):
-
-    def __init__(self, root, attr: str, data=None):
-        self.root = root
-        self.attr = attr
-        self.tracer = root.tracer
-        self._meta_data = data
-        self._node: Optional[Node] = None
-
-    @property
-    def node(self):
-        # the node for attributes is added lazily, since most will just be method calls
-        # which do not rely on the getitem call
-        if self._node is None:
-            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
-        return self._node
-
-    def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
-
-    def __repr__(self):
-        return f"ColoAttribute({self.node.name}, attr={self.attr})"
-
-
 class ColoTracer(Tracer):
     _custom_leaf_module: Set[Type[nn.Module]] = set()
     _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}

@@ -249,7 +86,6 @@ class ColoTracer(Tracer):
         # we will enter the module and split the bias-addition ops
         if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
             return False

         # user can specify which modules are leaf modules and which are not
         return (type(m) not in self._custom_non_leaf_module
                 and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))

@@ -306,9 +142,13 @@ class ColoTracer(Tracer):
             mod = self.root.get_submodule(target)
             self.disable_module_getattr = True
             try:
-                proxy.meta_data = self._custom_leaf_module_impl.get(type(mod),
-                                                                    mod.forward)(*tree_map(unwrap_fn, args),
-                                                                                 **tree_map(unwrap_fn, kwargs))
+                args = tree_map(unwrap_fn, args)
+                kwargs = tree_map(unwrap_fn, kwargs)
+                if type(mod) in self._custom_leaf_module:
+                    target = self._custom_leaf_module_impl[type(mod)]
+                    proxy.meta_data = target(mod, *args, **kwargs)
+                else:
+                    proxy.meta_data = mod.forward(*args, **kwargs)
             finally:
                 self.disable_module_getattr = False
             return proxy

@ -320,15 +160,21 @@ class ColoTracer(Tracer):
|
||||||
|
|
||||||
def trace(self,
|
def trace(self,
|
||||||
root: torch.nn.Module,
|
root: torch.nn.Module,
|
||||||
concrete_args: Optional[Dict[str, torch.Tensor]] = {},
|
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
meta_args: Optional[Dict[str, torch.Tensor]] = {}) -> Graph:
|
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
|
||||||
|
|
||||||
|
if meta_args is None:
|
||||||
|
meta_args = {}
|
||||||
|
|
||||||
|
if concrete_args is None:
|
||||||
|
concrete_args = {}
|
||||||
|
|
||||||
# check concrete and meta args have valid names
|
# check concrete and meta args have valid names
|
||||||
sig = inspect.signature(root.forward)
|
sig = inspect.signature(root.forward)
|
||||||
sig_names = set(sig.parameters.keys())
|
sig_names = set(sig.parameters.keys())
|
||||||
meta_arg_names = set(meta_args.keys())
|
meta_arg_names = set(meta_args.keys())
|
||||||
concrete_arg_names = set(concrete_args.keys())
|
concrete_arg_names = set(concrete_args.keys())
|
||||||
|
non_concrete_arg_names = sig_names - concrete_arg_names
|
||||||
# update concrete args with default values
|
# update concrete args with default values
|
||||||
for k, v in sig.parameters.items():
|
for k, v in sig.parameters.items():
|
||||||
if k in sig_names - meta_arg_names and \
|
if k in sig_names - meta_arg_names and \
|
||||||
|
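Swapping the ``{}`` defaults for ``None`` fixes Python's shared-mutable-default pitfall: a dict default is built once at function definition and then reused by every call. A standalone illustration (not Colossal-AI code):

.. code-block:: python

    def bad(args={}):
        args['calls'] = args.get('calls', 0) + 1    # mutates the one shared default dict
        return args

    print(bad())    # {'calls': 1}
    print(bad())    # {'calls': 2} -- state leaked between unrelated calls

    def good(args=None):
        if args is None:
            args = {}    # fresh dict on every call
        args['calls'] = args.get('calls', 0) + 1
        return args

    print(good())   # {'calls': 1}
    print(good())   # {'calls': 1}
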
@@ -352,6 +198,34 @@ class ColoTracer(Tracer):
         self.graph = super().trace(root, concrete_args=concrete_args)
         self.mod_dir = ''
         self.graph.lint()
+
+        for node in self.graph.nodes:
+            if node.op == "placeholder":
+                # Removing default values for inputs as the forward pass will fail with them.
+                if node.target in non_concrete_arg_names:
+                    node.args = ()
+                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+                    # It cannot infer on the attributes and methods the input should have, and fails.
+                    node.type = torch.Tensor
+                # It is a concrete arg so it is not used and should be removed.
+                else:
+                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+                        # Newer versions of torch.fx emit an assert statement
+                        # for concrete arguments; delete those before we delete
+                        # the concrete arg.
+                        to_delete = []
+                        for user in node.users:
+                            if user.target == torch.fx._symbolic_trace._assert_is_none:
+                                to_delete.append(user)
+                        for user in to_delete:
+                            self.graph.erase_node(user)
+
+                    self.graph.erase_node(node)
+
+            # TODO: solves GraphModule creation.
+            # Without this, return type annotation "Tuple" is causing code execution failure.
+            if node.op == "output":
+                node.type = None
         return self.graph

     @contextmanager

@@ -454,7 +328,7 @@ class ColoTracer(Tracer):
             if node.op == "output":
                 node.type = None
         self.graph.lint()

     def getattr(self, attr, attr_val, parameter_proxy_cache):
         return self._module_getattr(attr, attr_val, parameter_proxy_cache)

@@ -487,134 +361,3 @@ class ColoTracer(Tracer):
             return maybe_parameter_proxy

         return attr_val
-
-
-def symbolic_trace(
-    root: Union[torch.nn.Module, Callable[..., Any]],
-    concrete_args: Optional[Dict[str, Any]] = {},
-    meta_args: Optional[Dict[str, Any]] = {},
-    trace_act_ckpt: bool = False,
-    bias_addition_split: bool = False,
-) -> ColoGraphModule:
-    """
-    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
-    attached to the ``Node``s.
-
-    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
-    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
-
-    This tracer is able to trace basic control flow and for loops.
-
-    It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
-    (See ./bias_addition.py for more details).
-
-    Examples:
-    1. Tracing a ``torch.nn.Module`` with control flow.
-
-    .. code-block:: python
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2)
-
-            def forward(self, x):
-                if x.size(0) > 1:
-                    x = x.sum(dim=0)
-                return self.linear(x)
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
-
-        # traced code like:
-        # def forward(self, x):
-        #     linear_1 = self.linear(x)
-        #     return linear_1
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
-
-        # traced code like:
-        # def forward(self, x):
-        #     sum = x.sum(dim=0); x = None
-        #     linear = self.linear(sum); sum = None
-        #     return linear
-
-    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
-
-    .. code-block:: python
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2)
-
-            def forward(self, x):
-                def custom_forward(x):
-                    return self.linear(x)
-                return torch.utils.checkpoint.checkpoint(custom_forward, x)
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
-
-        # traced code like:
-        # def checkpoint_0(self, x):
-        #     linear = self.linear(x); x = None
-        #     return linear
-        #
-        # def forward(self, x):
-        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
-        #     return linear
-
-    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
-
-    .. code-block:: python
-
-        class MyModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(2, 2, bias=True)
-
-            def forward(self, x):
-                return self.linear(x)
-
-        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
-
-        # traced code like:
-        # def forward(self, x):
-        #     linear_bias = self.linear.bias
-        #     linear_weight = self.linear.weight
-        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
-        #     add = linear + linear_bias; linear = linear_bias = None
-        #     return add
-
-    Args:
-        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
-        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
-            Defaults to {}.
-        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
-            for tracing control flow. Defaults to {}.
-        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
-            Defaults to False.
-        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
-
-    Returns:
-        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
-
-    Remarks:
-        This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
-        any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
-        repo. We welcome any feedback and contributions to enhance the extensibility of
-        Colossal-AI.
-    """
-    if meta_args:
-        device, orig_device = _default_device(), _current_device(root)
-        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
-        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
-                           bias_addition_split=bias_addition_split).trace(root.to(device),
-                                                                          concrete_args=concrete_args,
-                                                                          meta_args=tree_map(wrap_fn, meta_args))
-        if trace_act_ckpt:
-            graph.set_codegen(ActivationCheckpointCodeGen())
-        root.to(orig_device)
-    else:
-        graph = Tracer().trace(root, concrete_args=concrete_args)
-    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
-    return ColoGraphModule(root, graph, name)

@@ -1,5 +1,4 @@
 from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
 from .registry import model_zoo

 __all__ = ['model_zoo']

@@ -17,6 +17,14 @@ def data_gen():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


+def seq_classification_data_gen():
+    # batch size should be 1 if no padding token is defined
+    input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
 output_transform_fn = lambda x: x

 config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
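The batch-of-one generator works around ``GPT2ForSequenceClassification`` pooling on the last non-padding token: Hugging Face raises an error for batch sizes > 1 when no ``pad_token_id`` is defined. If larger batches were ever needed here, a hedged alternative is to reuse EOS as the pad token:

.. code-block:: python

    # sketch only; GPT-2 defines no pad token by default
    config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
    config.pad_token_id = config.eos_token_id
    model = transformers.GPT2ForSequenceClassification(config)

    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)    # batch > 1 now allowed
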
@@ -44,6 +52,6 @@ model_zoo.register(name='transformers_gpt_for_token_classification',
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_sequence_classification',
                    model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=seq_classification_data_gen,
                    output_transform_fn=output_transform_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))

@@ -1,5 +1,6 @@
 import pytest
 import torch
+from packaging import version
 from torch.utils.checkpoint import checkpoint

 try:

@@ -73,7 +74,7 @@ class AddmmModel(torch.nn.Module):
         return x


-@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("bias_addition_split", [True, False])
 @pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
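The rewrite matters because comparing version strings lexicographically misorders double-digit components: ``'1.9.0' < '1.12.0'`` is ``False`` since ``'9' > '1'`` character by character, so the old guard could skip or run tests on the wrong releases. A quick demonstration:

.. code-block:: python

    from packaging import version

    print('1.9.0' < '1.12.0')                                  # False -- lexicographic string order
    print(version.parse('1.9.0') < version.parse('1.12.0'))    # True  -- semantic version order
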
@@ -3,7 +3,8 @@ from numpy import isin
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_flatten

-from colossalai.fx import symbolic_trace
+# from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace


 def trace_model_and_compare_output(model, data_gen):

@@ -1,4 +1,7 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

 from tests.kit.model_zoo import model_zoo

@@ -6,6 +9,7 @@ BATCH_SIZE = 2
 SEQ_LENGTH = 16


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_albert():
     sub_registry = model_zoo.get_sub_registry('transformers_albert')

@@ -1,8 +1,12 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

 from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_bert():
     sub_registry = model_zoo.get_sub_registry('transformers_bert')

@@ -1,16 +1,24 @@
 import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

 from tests.kit.model_zoo import model_zoo


-# TODO: remove this skip once we handle the latest gpt model
-@pytest.mark.skip
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_gpt():
     sub_registry = model_zoo.get_sub_registry('transformers_gpt')

     for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
         model = model_fn()
+
+        # TODO: support the following models
+        # 1. GPT2DoubleHeadsModel
+        # as they are not supported, let's skip them
+        if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
+            continue
+
         trace_model_and_compare_output(model, data_gen_fn)

@@ -1,8 +1,12 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

 from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_opt():
     sub_registry = model_zoo.get_sub_registry('transformers_opt')

@@ -1,8 +1,12 @@
+import pytest
+import torch
 from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version

 from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_t5():
     sub_registry = model_zoo.get_sub_registry('transformers_t5')

@@ -1,8 +1,8 @@
 import pytest
-import timm.models as tm
 import torch
+from packaging import version

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo

@@ -42,6 +42,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
         f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_timm_models():
     torch.backends.cudnn.deterministic = True

@@ -1,20 +1,18 @@
-import re
+import pytest
 import torch
+from packaging import version
 from torchaudio_utils import trace_and_compare

 from tests.kit.model_zoo import model_zoo


+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
 def test_torchaudio_models():
     torch.backends.cudnn.deterministic = True

     sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

     for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
-        # FIXME(ver217): temporarily skip these models
-        if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
-            continue
         model = model_fn()
         trace_and_compare(model,
                           data_gen_fn,

@@ -1,6 +1,6 @@
 import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace


 def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):

@@ -1,7 +1,7 @@
 import pytest
 import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo

 BATCH = 2

@@ -1,7 +1,7 @@
 import pytest
 import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo

 BATCH = 2

@@ -1,6 +1,6 @@
 import torch

-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
 from tests.kit.model_zoo import model_zoo