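"""Tests for the checkpoint I/O parameter redistribution helpers.

Each case builds a ParamRedistMeta describing a target tensor-parallel (TP)
and/or ZeRO layout, shards a small random tensor with flatten_zero_param,
split_tp_param, or unmerge_param, and compares the result against a manually
constructed reference split.
"""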
import torch

from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param


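# Even ZeRO split: a 4x4 tensor (16 elements) is flattened and sharded across
# 4 DP ranks at offsets [0, 4, 8, 12], i.e. one 4-element slice per rank.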
def test_flatten_zero_param_even() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


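# Uneven ZeRO split: with zero_start_dp_rank=1 and offsets [0, 13], only DP
# ranks 1 and 2 hold data (13 and 3 elements); ranks 0 and 3 get empty shards.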
def test_flatten_zero_param_uneven() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
    flat_tensors = flat_tensors[1:-1]
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
    unmerged_tensors = unmerged_tensors[1:-1]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


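# 1D row-parallel TP: the tensor is chunked into 4 parts along dim 0; with no
# ZeRO, each TP rank's entry from unmerge_param holds exactly one shard.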
def test_split_tp_param_1d_row() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


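# 1D column-parallel TP: same as the row case, but chunked along dim 1.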
def test_split_tp_param_1d_col() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


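# 2D TP: 2 parts along dim 0 and 3 parts along dim 1 give 6 shards, ordered
# row-major over the (dim 0, dim 1) grid.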
def test_split_tp_param_2d() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


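# 2D TP with the shard dims listed in reverse order ([1, 0] with [3, 2]
# parts); the resulting shards must match the plain 2D case above.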
def test_split_tp_param_2d_reverse() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


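# Hybrid TP + ZeRO: a 4x6 tensor is first split into 6 TP shards of 4
# elements each (a 2x3 grid), and each shard is then flattened and split
# between 2 DP ranks at offsets [0, 1] (1 element and 3 elements).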
def test_unmerge_param_hybrid() -> None:
    redist_meta = ParamRedistMeta(2,
                                  6,
                                  tp_shard_dims=[1, 0],
                                  tp_num_parts=[3, 2],
                                  zero_start_dp_rank=0,
                                  zero_offsets=[0, 1])
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
    for tp_rank in range(6):
        for dp_rank in range(2):
            assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])


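# Degenerate case: with a single TP rank and a single DP rank, unmerge_param
# must return the original tensor unchanged.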
def test_unmerge_param_dummy() -> None:
    redist_meta = ParamRedistMeta(1, 1)
    orig_tensor = torch.rand(4, 6)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
    assert torch.equal(orig_tensor, unmerged_tensors[0][0])


if __name__ == '__main__':
    test_flatten_zero_param_even()
    test_flatten_zero_param_uneven()
    test_split_tp_param_1d_row()
    test_split_tp_param_1d_col()
    test_split_tp_param_2d()
    test_split_tp_param_2d_reverse()
    test_unmerge_param_hybrid()
    test_unmerge_param_dummy()