mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
533 lines
17 KiB
533 lines
17 KiB
import torch |
|
|
|
from colossalai.fx.tracer.meta_patch import patched_module |
|
from colossalai.testing import clear_cache_before_run |
|
|
|
|
|
def _run(data, module, patch_fn): |
|
try: |
|
if isinstance(data, dict): |
|
output = patch_fn(module, **data) |
|
if isinstance(data, tuple) or isinstance(data, list): |
|
output = patch_fn(module, *data) |
|
else: |
|
output = patch_fn(module, data) |
|
return output |
|
except Exception as e: |
|
return e |
|
|
|
|
|
def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape):
    """Assert that a meta patch produces meta tensor(s) of the expected shape(s).

    Args:
        data: Input forwarded to ``patch_fn`` through ``_run`` (dict -> kwargs,
            tuple/list -> positional args, otherwise a single argument).
        module: The ``torch.nn.Module`` instance handed to ``patch_fn``.
        patch_fn: The patched forward function under test.
        expect_exception: When True the call must raise ``AssertionError``;
            ``output_shape`` is ignored in that case.
        output_shape: The expected shape, or a sequence of shapes when
            ``patch_fn`` returns a tuple of tensors.
    """
    output = _run(data, module, patch_fn)

    if expect_exception:
        assert isinstance(output, AssertionError), f"expected AssertionError, got {output!r}"
        return

    # _run converts exceptions into return values; surface the actual error.
    assert not isinstance(output, Exception), f"unexpected exception from patch: {output!r}"

    if isinstance(output, tuple):
        # zip() would silently ignore surplus or missing outputs.
        assert len(output) == len(output_shape), f"expected {len(output_shape)} outputs, got {len(output)}"
        for item, shape in zip(output, output_shape):
            assert item.is_meta
            assert item.shape == shape, f"expected shape {shape}, got {item.shape}"
    else:
        assert output.is_meta
        assert output.shape == output_shape, f"expected shape {output_shape}, got {output.shape}"
|
|
|
|
|
@clear_cache_before_run()
def test_linear():
    """Linear meta patch: correct output shape, and rejection of a mismatched feature dim."""
    good_input = torch.rand(2, 4, device="meta")
    linear = torch.nn.Linear(4, 2)
    patch = patched_module.torch_nn_linear

    # Trailing dim 4 == in_features, so (2, 4) maps to (2, out_features=2).
    _assert_output_shape(good_input, linear, patch, False, torch.Size([2, 2]))

    # Trailing dim 2 != in_features 4: the patch is expected to raise.
    bad_input = torch.rand(2, 2, device="meta")
    _assert_output_shape(bad_input, linear, patch, True, None)
|
|
|
|
|
@clear_cache_before_run()
def test_rnn():
    """RNN meta patch: matches eager output shapes and rejects a wrong input feature size."""
    # Eager reference: input (seq_len=5, batch=3, input_size=10) and
    # h0 (num_layers=2, batch=3, hidden_size=20) for RNN(10, 20, 2).
    data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20))
    module = torch.nn.RNN(10, 20, 2)
    output, hn = module(*data)

    meta_data = tuple(t.to("meta") for t in data)
    _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape))

    # Input feature size 1 != input_size 10: the patch is expected to raise.
    # No eager forward is needed here because no output shape is compared.
    bad_meta_data = (torch.empty(5, 3, 1, device="meta"), torch.empty(2, 3, 20, device="meta"))
    _assert_output_shape(bad_meta_data, module, patched_module.torch_nn_rnn, True, None)
|
|
|
|
|
@clear_cache_before_run()
def test_embedding():
    """Check ``patched_module.torch_nn_normalize`` on normalization layers.

    Despite its name, this test covers LayerNorm, GroupNorm and
    BatchNorm1d/2d/3d (not ``torch.nn.Embedding``). The name is kept so that
    existing test selections keep working. Accepted inputs must come back as
    meta tensors with the input's shape. Inputs of an unsupported rank for a
    BatchNorm variant must raise ``AssertionError``.
    """
    patch_fn = patched_module.torch_nn_normalize

    # LayerNorm(4) normalizes over the trailing dim of size 4.
    data = torch.rand(2, 4, device="meta")
    _assert_output_shape(data, torch.nn.LayerNorm(4), patch_fn, False, data.shape)

    # GroupNorm's num_channels must match the input's channel dim (dim 1 == 4)
    # and be divisible by num_groups.
    gn = torch.nn.GroupNorm(4, num_channels=4)
    _assert_output_shape(data, gn, patch_fn, False, data.shape)

    # Each BatchNorm has num_features=4, so every input puts 4 at the channel
    # dim (dim 1); only the rank differs between accepted and rejected cases.
    cases = [
        # (module, accepted input shapes, rejected input shapes)
        (torch.nn.BatchNorm1d(4), [(2, 4), (2, 4, 3)], [(1, 4, 3, 4)]),
        (torch.nn.BatchNorm2d(4), [(1, 4, 3, 4)], [(2, 4, 3)]),
        (torch.nn.BatchNorm3d(4), [(1, 4, 2, 3, 4)], [(1, 4, 3, 4)]),
    ]
    for bn, accepted, rejected in cases:
        for shape in accepted:
            data = torch.rand(*shape, device="meta")
            _assert_output_shape(data, bn, patch_fn, False, data.shape)
        for shape in rejected:
            data = torch.rand(*shape, device="meta")
            _assert_output_shape(data, bn, patch_fn, True, None)
|
|
|
|
|
@clear_cache_before_run()
def test_conv1d():
    """Conv1d meta patch agrees with eager output shapes across padding/dilation settings."""
    data = torch.rand(2, 3, 4)

    configs = [
        dict(in_channels=3, out_channels=4, kernel_size=2),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect"),
    ]
    for kwargs in configs:
        conv1d = torch.nn.Conv1d(**kwargs)
        # Eager run on the real tensor gives the reference shape; the patch
        # itself is exercised on a meta copy of the same input.
        expected_shape = conv1d(data).shape
        _assert_output_shape(
            data=data.to("meta"),
            module=conv1d,
            patch_fn=patched_module.torch_nn_conv1d,
            expect_exception=False,
            output_shape=expected_shape,
        )
|
|
|
|
|
@clear_cache_before_run()
def test_conv2d():
    """Conv2d meta patch agrees with eager output shapes across padding/dilation/padding_mode.

    The eager module run on a real tensor provides the reference shape. The
    patch is then run on a meta copy of that input, as in ``test_conv1d``.
    """
    data = torch.rand(2, 3, 4, 4)
    meta_data = data.to("meta")

    configs = [
        dict(in_channels=3, out_channels=4, kernel_size=2),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect"),
    ]
    for kwargs in configs:
        conv2d = torch.nn.Conv2d(**kwargs)
        materialized_output = conv2d(data)
        _assert_output_shape(
            data=meta_data,
            module=conv2d,
            patch_fn=patched_module.torch_nn_conv2d,
            expect_exception=False,
            output_shape=materialized_output.shape,
        )
|
|
|
|
|
@clear_cache_before_run()
def test_conv3d():
    """Conv3d meta patch agrees with eager output shapes across padding/dilation/padding_mode.

    The eager module run on a real tensor provides the reference shape. The
    patch is then run on a meta copy of that input, as in ``test_conv1d``.
    """
    data = torch.rand(2, 3, 4, 4, 4)
    meta_data = data.to("meta")

    configs = [
        dict(in_channels=3, out_channels=4, kernel_size=2),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2),
        dict(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect"),
    ]
    for kwargs in configs:
        conv3d = torch.nn.Conv3d(**kwargs)
        materialized_output = conv3d(data)
        _assert_output_shape(
            data=meta_data,
            module=conv3d,
            patch_fn=patched_module.torch_nn_conv3d,
            expect_exception=False,
            output_shape=materialized_output.shape,
        )
|
|
|
|
|
@clear_cache_before_run()
def test_conv_transpose1d():
    """ConvTranspose1d meta patch matches eager output shape with default and unit padding."""
    sample = torch.rand(2, 3, 4)

    for extra_kwargs in ({}, {"padding": 1}):
        layer = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        reference_shape = layer(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=layer,
            patch_fn=patched_module.torch_nn_convtranspose1d,
            expect_exception=False,
            output_shape=reference_shape,
        )
|
|
|
|
|
@clear_cache_before_run()
def test_conv_transpose2d():
    """ConvTranspose2d meta patch matches eager output shape with default and unit padding."""
    sample = torch.rand(2, 3, 4, 4)

    for extra_kwargs in ({}, {"padding": 1}):
        layer = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        reference_shape = layer(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=layer,
            patch_fn=patched_module.torch_nn_convtranspose2d,
            expect_exception=False,
            output_shape=reference_shape,
        )
|
|
|
|
|
@clear_cache_before_run()
def test_conv_transpose3d():
    """ConvTranspose3d meta patch matches eager output shape with default and unit padding."""
    sample = torch.rand(2, 3, 4, 4, 4)

    for extra_kwargs in ({}, {"padding": 1}):
        layer = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, **extra_kwargs)
        reference_shape = layer(sample).shape
        _assert_output_shape(
            data=sample.to("meta"),
            module=layer,
            patch_fn=patched_module.torch_nn_convtranspose3d,
            expect_exception=False,
            output_shape=reference_shape,
        )
|
|
|
|
|
@clear_cache_before_run()
def test_pool1d():
    """1D max/avg pooling patches: 3D and 2D inputs match eager shapes, a 4D input must raise."""
    layer_and_patch = (
        (torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d),
        (torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d),
    )
    for pool_cls, patch in layer_and_patch:
        pool = pool_cls(kernel_size=3)

        # Supported ranks: compare against the eager module's output shape.
        for shape in ((2, 3, 4), (2, 4)):
            sample = torch.rand(*shape)
            _assert_output_shape(
                data=sample,
                module=pool,
                patch_fn=patch,
                expect_exception=False,
                output_shape=pool(sample).shape,
            )

        # A 4D input is expected to be rejected by the 1D pooling patch.
        _assert_output_shape(
            data=torch.rand(2, 3, 4, 4), module=pool, patch_fn=patch, expect_exception=True, output_shape=None
        )
|
|
|
|
|
@clear_cache_before_run()
def test_pool2d():
    """2D max/avg pooling patches: 4D and 3D inputs match eager shapes, a 5D input must raise."""
    # (input shape, whether the patch is expected to raise)
    checks = [((2, 3, 4, 4), False), ((2, 4, 4), False), ((2, 3, 4, 4, 4), True)]

    for pool_cls, patch in (
        (torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d),
        (torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d),
    ):
        pool = pool_cls(kernel_size=3)
        for shape, should_fail in checks:
            x = torch.rand(*shape)
            # Only run the eager module when a reference shape is needed.
            expected = None if should_fail else pool(x).shape
            _assert_output_shape(
                data=x, module=pool, patch_fn=patch, expect_exception=should_fail, output_shape=expected
            )
|
|
|
|
|
@clear_cache_before_run()
def test_pool3d():
    """3D max/avg pooling patches: 5D and 4D inputs match eager shapes, a 3D input must raise."""
    # (input shape, whether the patch is expected to raise)
    checks = [((2, 3, 4, 4, 4), False), ((2, 4, 4, 4), False), ((2, 3, 4), True)]

    for pool_cls, patch in (
        (torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d),
        (torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d),
    ):
        pool = pool_cls(kernel_size=3)
        for shape, should_fail in checks:
            x = torch.rand(*shape)
            # Only run the eager module when a reference shape is needed.
            expected = None if should_fail else pool(x).shape
            _assert_output_shape(
                data=x, module=pool, patch_fn=patch, expect_exception=should_fail, output_shape=expected
            )
|
|
|
|
|
# Adaptive pooling is configured by output size rather than kernel size, so it is tested separately.
|
@clear_cache_before_run()
def test_adaptive_pooling_1d():
    """AdaptiveAvgPool1d patch: 2D and 3D inputs match eager shapes, a 4D input must raise."""
    pool = torch.nn.AdaptiveAvgPool1d(output_size=3)
    patch = patched_module.torch_nn_adapative_pooling_1d

    for shape in ((3, 4), (2, 3, 4)):
        x = torch.rand(*shape)
        expected = pool(x).shape
        _assert_output_shape(data=x, module=pool, patch_fn=patch, expect_exception=False, output_shape=expected)

    too_many_dims = torch.rand(2, 3, 4, 5)
    _assert_output_shape(data=too_many_dims, module=pool, patch_fn=patch, expect_exception=True, output_shape=None)
|
|
|
|
|
@clear_cache_before_run()
def test_adaptive_pooling_2d():
    """AdaptiveAvgPool2d patch: a 2D input must raise; 3D and 4D inputs match eager shapes."""
    pool = torch.nn.AdaptiveAvgPool2d(output_size=3)
    patch = patched_module.torch_nn_adapative_pooling_2d

    too_few_dims = torch.rand(3, 4)
    _assert_output_shape(data=too_few_dims, module=pool, patch_fn=patch, expect_exception=True, output_shape=None)

    for shape in ((2, 3, 4), (2, 3, 4, 5)):
        x = torch.rand(*shape)
        expected = pool(x).shape
        _assert_output_shape(data=x, module=pool, patch_fn=patch, expect_exception=False, output_shape=expected)
|
|
|
|
|
@clear_cache_before_run()
def test_adaptive_pooling_3d():
    """AdaptiveAvgPool3d patch: a 3D input must raise; 4D and 5D inputs match eager shapes."""
    pool = torch.nn.AdaptiveAvgPool3d(output_size=3)
    patch = patched_module.torch_nn_adapative_pooling_3d

    too_few_dims = torch.rand(3, 4, 5)
    _assert_output_shape(data=too_few_dims, module=pool, patch_fn=patch, expect_exception=True, output_shape=None)

    for shape in ((2, 3, 4, 5), (2, 3, 4, 5, 6)):
        x = torch.rand(*shape)
        expected = pool(x).shape
        _assert_output_shape(data=x, module=pool, patch_fn=patch, expect_exception=False, output_shape=expected)
|
|
|