You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_tensor/test_padded_tensor.py

47 lines
1.7 KiB

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_padded_tensor(rank, world_size, port):
    """Per-rank worker that exercises the padded-tensor helpers on a sharded tensor.

    Started once per process by ``spawn``. Each rank builds a 2x2 device mesh
    over ranks 0-3, distributes a random 32x64 CUDA tensor with a sharding spec
    on tensor dim 0, and then asserts that padding, ``clone()``, ``detach()``
    and unpadding keep the padded/distributed markers and the expected shapes.

    Args:
        rank: Rank of this process, forwarded to ``launch``.
        world_size: Number of processes, forwarded to ``launch``. The mesh
            below is hard-coded to 2x2 over ranks [0, 1, 2, 3], so this is
            expected to be 4.
        port: Port forwarded to ``launch`` together with ``host="localhost"``.
    """
    disable_existing_loggers()
    launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # --- Build a distributed tensor -------------------------------------------
    source = torch.rand(32, 64).to("cuda")
    mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    # Partition spec {0: [0]}: presumably tensor dim 0 sharded over mesh axis 0 — confirm.
    spec = ShardingSpec(dim_size=source.dim(), dim_partition_dict={0: [0]})
    sharded = distribute_tensor(source, mesh, spec)

    # --- Padding keeps the distributed layout ---------------------------------
    padded = to_padded_tensor(sharded, current_length=64, padding_dim=0)
    assert padded.dist_layout == sharded.dist_layout

    # --- clone() and detach() must carry both markers along -------------------
    cloned = padded.clone()
    assert is_padded_tensor(cloned)
    assert is_distributed_tensor(cloned)

    detached = padded.detach()
    assert is_padded_tensor(detached)
    assert is_distributed_tensor(detached)

    # --- Unpadding restores the local shard shape and stays distributed -------
    unpadded = to_unpadded_tensor(padded)
    assert unpadded.shape == sharded.shape
    assert is_distributed_tensor(unpadded)

    # Gathering the shards back must reproduce the original global shape.
    assert to_global(unpadded).shape == source.shape
@rerun_if_address_is_in_use()
def test_padded_tensor():
    """Spawn one ``check_padded_tensor`` worker per rank on a 4-process group.

    Decorated with ``rerun_if_address_is_in_use`` so the test is retried when
    the rendezvous address is already taken.
    """
    # Four processes, matching the 2x2 device mesh over ranks [0, 1, 2, 3]
    # that check_padded_tensor builds.
    num_processes = 4
    spawn(check_padded_tensor, num_processes)
if __name__ == "__main__":
    # Allow running this distributed test directly (e.g. `python test_padded_tensor.py`)
    # without going through pytest.
    test_padded_tensor()