mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
126 lines
2.7 KiB
126 lines
2.7 KiB
import operator
|
|
|
|
import torch
|
|
|
|
__all__ = [
|
|
"ELEMENTWISE_MODULE_OP",
|
|
"ELEMENTWISE_FUNC_OP",
|
|
"RESHAPE_FUNC_OP",
|
|
"CONV_MODULE_OP",
|
|
"CONV_FUNC_OP",
|
|
"LINEAR_MODULE_OP",
|
|
"LINEAR_FUNC_OP",
|
|
"BATCHNORM_MODULE_OP",
|
|
"POOL_MODULE_OP",
|
|
"NON_PARAM_FUNC_OP",
|
|
"BCAST_FUNC_OP",
|
|
"EMBEDDING_MODULE_OP",
|
|
"LAYERNORM_MODULE_OP",
|
|
"ELEMENTWISE_METHOD_OP",
|
|
"RESHAPE_METHOD_OP",
|
|
"INFINITY_COST",
|
|
]
|
|
|
|
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
|
|
ELEMENTWISE_FUNC_OP = [
|
|
torch.abs,
|
|
torch.cos,
|
|
torch.exp,
|
|
operator.neg,
|
|
torch.multiply,
|
|
torch.nn.functional.relu,
|
|
torch.nn.functional.dropout,
|
|
# softmax should not be here
|
|
torch.nn.functional.softmax,
|
|
]
|
|
ELEMENTWISE_METHOD_OP = [
|
|
torch.Tensor.to,
|
|
torch.Tensor.type,
|
|
# TODO: contiguous maybe need some extra processes.
|
|
torch.Tensor.contiguous,
|
|
]
|
|
RESHAPE_FUNC_OP = [
|
|
torch.flatten,
|
|
torch.reshape,
|
|
torch.transpose,
|
|
torch.split,
|
|
torch.permute,
|
|
operator.getitem,
|
|
]
|
|
RESHAPE_METHOD_OP = [
|
|
torch.Tensor.view,
|
|
torch.Tensor.unsqueeze,
|
|
torch.Tensor.split,
|
|
torch.Tensor.permute,
|
|
torch.Tensor.transpose,
|
|
]
|
|
BCAST_FUNC_OP = [
|
|
torch.add,
|
|
torch.sub,
|
|
torch.mul,
|
|
torch.div,
|
|
torch.floor_divide,
|
|
torch.true_divide,
|
|
operator.add,
|
|
operator.sub,
|
|
operator.mul,
|
|
operator.floordiv,
|
|
operator.truediv,
|
|
torch.matmul,
|
|
operator.pow,
|
|
torch.pow,
|
|
]
|
|
CONV_MODULE_OP = [
|
|
torch.nn.Conv1d,
|
|
torch.nn.Conv2d,
|
|
torch.nn.Conv3d,
|
|
torch.nn.ConvTranspose1d,
|
|
torch.nn.ConvTranspose2d,
|
|
torch.nn.ConvTranspose3d,
|
|
]
|
|
CONV_FUNC_OP = [
|
|
torch.conv1d,
|
|
torch.conv2d,
|
|
torch.conv3d,
|
|
torch.conv_transpose1d,
|
|
torch.conv_transpose2d,
|
|
torch.conv_transpose3d,
|
|
]
|
|
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
|
|
LINEAR_MODULE_OP = [torch.nn.Linear]
|
|
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
|
|
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
|
|
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
|
|
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
|
|
NON_PARAM_FUNC_OP = [
|
|
torch.flatten,
|
|
torch.reshape,
|
|
torch.abs,
|
|
torch.cos,
|
|
torch.exp,
|
|
operator.neg,
|
|
torch.multiply,
|
|
torch.nn.functional.relu,
|
|
torch.nn.functional.dropout,
|
|
torch.flatten,
|
|
torch.where,
|
|
operator.pow,
|
|
torch.pow,
|
|
torch.tanh,
|
|
torch.add,
|
|
torch.sub,
|
|
torch.mul,
|
|
torch.div,
|
|
torch.floor_divide,
|
|
torch.true_divide,
|
|
operator.add,
|
|
operator.sub,
|
|
operator.mul,
|
|
operator.floordiv,
|
|
operator.truediv,
|
|
# softmax should not be here
|
|
torch.nn.functional.softmax,
|
|
]
|
|
|
|
INFINITY_COST = 1e13
|