# mirror of https://github.com/hpcaitech/ColossalAI
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn


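# Checks that a padded distributed tensor keeps its layout through clone/detach
# and round-trips back to the original shape via unpadding and gathering.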
def check_padded_tensor(rank, world_size, port):
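    # Set up the distributed environment for this rank.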
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    original_tensor = torch.rand(32, 64).to("cuda")

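    # Build a 2x2 device mesh over ranks 0-3 and shard dim 0 of the tensor across mesh axis 0.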
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

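    # Pad the distributed tensor along dim 0 (current_length=64); the distributed layout should be unchanged.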
    padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
    assert padded_tensor.dist_layout == d_tensor.dist_layout

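    # clone() should preserve both the padded and the distributed properties.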
    tensor_copy = padded_tensor.clone()
    assert is_padded_tensor(tensor_copy)
    assert is_distributed_tensor(tensor_copy)

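    # detach() should preserve them as well.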
    tensor_detached = padded_tensor.detach()
    assert is_padded_tensor(tensor_detached)
    assert is_distributed_tensor(tensor_detached)

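    # Removing the padding should restore the sharded tensor's shape.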
    unpadded_tensor = to_unpadded_tensor(padded_tensor)
    assert unpadded_tensor.shape == d_tensor.shape
    assert is_distributed_tensor(unpadded_tensor)

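    # Gathering the shards back into a global tensor should recover the original shape.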
    global_tensor = to_global(unpadded_tensor)
    assert global_tensor.shape == original_tensor.shape


@rerun_if_address_is_in_use()
def test_padded_tensor():
    world_size = 4
    spawn(check_padded_tensor, world_size)


if __name__ == "__main__":
    test_padded_tensor()