import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.chunk import Chunk


def dist_sum(x):
    # All-reduce a Python scalar across ranks and return the summed value.
    temp = torch.tensor([x], device=get_current_device())
    dist.all_reduce(temp)
    return temp.item()


def add_param(param_list, param_cp_list, *args, **kwargs):
    # Create a random ColoParameter and keep a copy of it for later comparison.
    param = ColoParameter(torch.randn(*args, **kwargs))
    param_list.append(param)
    param_cp_list.append(param.clone())


def check_equal(param, param_cp):
    # Compare a parameter with its saved copy, moving the data across devices if needed.
    if param.device != param_cp.device:
        temp = param.data.to(param_cp.device)
    else:
        temp = param.data
    return torch.equal(temp, param_cp.data)


@parameterize('init_device', [None, torch.device('cpu')])
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
    world_size = torch.distributed.get_world_size()
    pg = ColoProcessGroup()
    my_chunk = Chunk(chunk_size=1024,
                     process_group=pg,
                     dtype=torch.float32,
                     init_device=init_device,
                     cpu_shard_init=True,
                     keep_gathered=keep_gathered,
                     pin_memory=pin_memory)

    param_list = []
    param_cp_list = []

    # 8*8*8 + 4*4 + 4*8*2 + 1*1*5 = 597 elements in total
    add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
    add_param(param_list, param_cp_list, 4, 4)
    add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
    add_param(param_list, param_cp_list, 1, 1, 5)

    for param in param_list:
        my_chunk.append_tensor(param)
    assert my_chunk.utilized_size == 597
    for param, param_cp in zip(param_list, param_cp_list):
        assert check_equal(param, param_cp)
    my_chunk.close_chunk()

    if keep_gathered is False:
        # With cpu_shard_init=True, the sharded chunk lands on CPU after close_chunk().
        assert my_chunk.cpu_shard.size(0) == 1024 // world_size
        assert my_chunk.device_type == 'cpu'
        assert my_chunk.can_move
        my_chunk.shard_move(get_current_device())
    else:
        # A gathered chunk keeps the full payload on CUDA and cannot be moved.
        assert my_chunk.cuda_global_chunk.size(0) == 1024
        assert my_chunk.device_type == 'cuda'
        assert not my_chunk.can_move

    assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
    flag = my_chunk.has_inf_or_nan
    assert not flag, "has_inf_or_nan is {}".format(flag)

    my_chunk.access_chunk()
    assert my_chunk.device_type == 'cuda'
    for param, param_cp in zip(param_list, param_cp_list):
        assert check_equal(param, param_cp)

    # Exercise the tensor state machine: HOLD -> COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE.
    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
    my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 3
    assert my_chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
    assert not my_chunk.can_release

    for param in param_list:
        my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
        my_chunk.tensor_trans_state(param, TensorState.HOLD_AFTER_BWD)
        my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)

    assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
    assert my_chunk.can_reduce
    my_chunk.reduce()
    assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4

    if keep_gathered is False:
        # After reduce(), the sharded chunk holds its shard on CUDA.
        assert my_chunk.cuda_shard.size(0) == 1024 // world_size
        assert my_chunk.device_type == 'cuda'
        assert my_chunk.can_move
    else:
        assert my_chunk.cuda_global_chunk.size(0) == 1024
        assert my_chunk.device_type == 'cuda'
        assert not my_chunk.can_move


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_chunk_basic()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_chunk_function(4)