ColossalAI/colossalai/auto_parallel/tensor_shard/constants.py

126 lines
2.7 KiB
Python
Raw Normal View History

import operator
import torch
# Names exported by a star-import of this module; every table defined below
# is listed here.
__all__ = [
    "ELEMENTWISE_MODULE_OP",
    "ELEMENTWISE_FUNC_OP",
    "RESHAPE_FUNC_OP",
    "CONV_MODULE_OP",
    "CONV_FUNC_OP",
    "LINEAR_MODULE_OP",
    "LINEAR_FUNC_OP",
    "BATCHNORM_MODULE_OP",
    "POOL_MODULE_OP",
    "NON_PARAM_FUNC_OP",
    "BCAST_FUNC_OP",
    "EMBEDDING_MODULE_OP",
    "LAYERNORM_MODULE_OP",
    "ELEMENTWISE_METHOD_OP",
    "RESHAPE_METHOD_OP",
    "INFINITY_COST",
]
# ``torch.nn`` module classes grouped under the element-wise category.
ELEMENTWISE_MODULE_OP = [
    torch.nn.Dropout,
    torch.nn.ReLU,
]
# Function targets (``torch`` / ``operator`` callables) grouped under the
# element-wise category.
ELEMENTWISE_FUNC_OP = [
    torch.abs,
    torch.cos,
    torch.exp,
    operator.neg,
    torch.multiply,
    torch.nn.functional.relu,
    torch.nn.functional.dropout,
    # NOTE: softmax has been flagged as not belonging in this list (it
    # normalizes along a dimension rather than acting per element). It is
    # kept here so existing behavior does not change.
    torch.nn.functional.softmax,
]
# ``torch.Tensor`` methods grouped under the element-wise category.
ELEMENTWISE_METHOD_OP = [
    torch.Tensor.to,
    torch.Tensor.type,
    # TODO: ``contiguous`` may need some extra processing.
    torch.Tensor.contiguous,
]
# Function targets grouped under the reshape category (shape/layout
# changing calls plus ``operator.getitem``).
RESHAPE_FUNC_OP = [
    torch.flatten,
    torch.reshape,
    torch.transpose,
    torch.split,
    torch.permute,
    operator.getitem,
]
# ``torch.Tensor`` methods grouped under the reshape category.
RESHAPE_METHOD_OP = [
    torch.Tensor.view,
    torch.Tensor.unsqueeze,
    torch.Tensor.split,
    torch.Tensor.permute,
    torch.Tensor.transpose,
]
# Function targets grouped under the broadcast category: binary arithmetic
# in both ``torch.*`` and ``operator.*`` spellings, plus ``torch.matmul``.
BCAST_FUNC_OP = [
    # torch arithmetic
    torch.add,
    torch.sub,
    torch.mul,
    torch.div,
    torch.floor_divide,
    torch.true_divide,
    # operator-module arithmetic
    operator.add,
    operator.sub,
    operator.mul,
    operator.floordiv,
    operator.truediv,
    # matrix product and power
    torch.matmul,
    operator.pow,
    torch.pow,
]
# Convolution module classes: regular variants first, then transposed ones.
CONV_MODULE_OP = [
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
]
# Convolution function targets: regular variants first, then transposed ones.
CONV_FUNC_OP = [
    torch.conv1d, torch.conv2d, torch.conv3d,
    torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d,
]
# Embedding module classes.
EMBEDDING_MODULE_OP = [
    torch.nn.modules.sparse.Embedding,
]

# Linear layers: module form, then the function targets grouped with it.
LINEAR_MODULE_OP = [
    torch.nn.Linear,
]
LINEAR_FUNC_OP = [
    torch.nn.functional.linear,
    torch.matmul,
    torch.bmm,
]

# Normalization module classes.
BATCHNORM_MODULE_OP = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
]
LAYERNORM_MODULE_OP = [
    torch.nn.LayerNorm,
]

# Pooling module classes.
POOL_MODULE_OP = [
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    torch.nn.AdaptiveAvgPool2d,
]
# Function targets grouped under the "non-param" category (the name suggests
# calls that carry no learnable parameters -- confirm against the consumers).
# Every callable is listed exactly once, so membership tests and iteration
# see no repeated entries.
NON_PARAM_FUNC_OP = [
    # shape manipulation
    torch.flatten,
    torch.reshape,
    # unary / activation-style calls
    torch.abs,
    torch.cos,
    torch.exp,
    operator.neg,
    torch.multiply,
    torch.nn.functional.relu,
    torch.nn.functional.dropout,
    torch.where,
    operator.pow,
    torch.pow,
    torch.tanh,
    # binary arithmetic, torch spelling
    torch.add,
    torch.sub,
    torch.mul,
    torch.div,
    torch.floor_divide,
    torch.true_divide,
    # binary arithmetic, operator-module spelling
    operator.add,
    operator.sub,
    operator.mul,
    operator.floordiv,
    operator.truediv,
    # NOTE: softmax has been flagged as not belonging in this list; it is
    # kept so existing behavior does not change.
    torch.nn.functional.softmax,
]
# Large finite float; presumably a stand-in for an "infinite" cost that
# still behaves under ordinary arithmetic -- confirm against the callers.
INFINITY_COST = 1.0e13