mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
2.0 KiB
51 lines
2.0 KiB
import torch |
|
from torch.nn import functional as F |
|
|
|
from colossalai.fx.tracer.meta_patch import patched_function |
|
from colossalai.testing import clear_cache_before_run |
|
|
|
|
|
@clear_cache_before_run()
def test_conv():
    """Check the patched conv functions against ``torch.nn.functional``.

    For 1, 2 and 3 spatial dimensions, the regular and the transposed
    convolution are each run through the real ``F.*`` function and through
    the matching ``patched_function.torch_nn_functional_*`` function, and the
    two output shapes are asserted to be equal.
    """
    # One row per spatial rank:
    # (rank, reference conv, patched conv, reference transposed conv, patched transposed conv)
    cases = (
        (
            1,
            F.conv1d,
            patched_function.torch_nn_functional_conv1d,
            F.conv_transpose1d,
            patched_function.torch_nn_functional_convtranspose1d,
        ),
        (
            2,
            F.conv2d,
            patched_function.torch_nn_functional_conv2d,
            F.conv_transpose2d,
            patched_function.torch_nn_functional_convtranspose2d,
        ),
        (
            3,
            F.conv3d,
            patched_function.torch_nn_functional_conv3d,
            F.conv_transpose3d,
            patched_function.torch_nn_functional_convtranspose3d,
        ),
    )

    for rank, ref_conv, patched_conv, ref_conv_t, patched_conv_t in cases:
        # Input is (3, 16, 10, ...) and the kernel is (3, 16, 3, ...), with
        # `rank` trailing spatial dimensions each.
        spatial_extent = (10,) * rank
        kernel_extent = (3,) * rank
        inputs = torch.rand(3, 16, *spatial_extent)
        kernel = torch.rand(3, 16, *kernel_extent)

        # Regular convolution: F.conv{rank}d
        expected = ref_conv(inputs, kernel)
        patched = patched_conv(inputs, kernel)
        assert expected.shape == patched.shape

        # Transposed convolution: F.conv_transpose{rank}d.
        # Swapping the first two kernel dims gives a (16, 3, ...) weight whose
        # leading dim matches the 16 input channels, as conv_transpose expects.
        kernel_t = torch.transpose(kernel, 0, 1)
        expected_t = ref_conv_t(inputs, kernel_t)
        patched_t = patched_conv_t(inputs, kernel_t)
        assert expected_t.shape == patched_t.shape
|
|
|
|
|
# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_conv()
|
|
|