|
|
|
@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|
|
|
|
from colossalai.utils.cuda import get_current_device |
|
|
|
|
from colossalai.utils import free_port |
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext |
|
|
|
|
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup |
|
|
|
|
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ColoTensor, ColoTensorSpec |
|
|
|
|
from colossalai.nn.parallel.data_parallel import ColoDDP |
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
|
|
@ -21,18 +21,20 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
|
|
|
|
|
|
|
|
|
def init_1d_row_spec(model, pg: ProcessGroup): |
|
|
|
|
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
|
|
|
|
with DistSpecManager.no_grad(): |
|
|
|
|
for n, p in model.named_parameters(): |
|
|
|
|
if 'weight' in n and 'ln' not in n: |
|
|
|
|
p.set_tensor_spec(*tensor_spec) |
|
|
|
|
|
|
|
|
|
for n, p in model.named_parameters(): |
|
|
|
|
p.set_process_group(pg) |
|
|
|
|
if 'weight' in n and 'ln' not in n: |
|
|
|
|
p.set_tensor_spec(*tensor_spec) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_1d_col_spec(model, pg: ProcessGroup): |
|
|
|
|
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) |
|
|
|
|
with DistSpecManager.no_grad(): |
|
|
|
|
for n, p in model.named_parameters(): |
|
|
|
|
if 'ln' not in n and ('weight' in n or 'bias' in n): |
|
|
|
|
p.set_tensor_spec(*spec) |
|
|
|
|
|
|
|
|
|
for n, p in model.named_parameters(): |
|
|
|
|
p.set_process_group(pg) |
|
|
|
|
if 'ln' not in n and ('weight' in n or 'bias' in n): |
|
|
|
|
p.set_tensor_spec(*spec) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_param_equal(model, torch_model, pg: ProcessGroup): |
|
|
|
@ -48,6 +50,7 @@ def check_grad_equal(model, torch_model, pg: ProcessGroup):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_gpt(init_spec_func, use_ddp): |
|
|
|
|
set_seed(13234) |
|
|
|
|
world_size = torch.distributed.get_world_size() |
|
|
|
|
pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1)) |
|
|
|
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2') |
|
|
|
@ -67,14 +70,16 @@ def run_gpt(init_spec_func, use_ddp):
|
|
|
|
|
model = ColoDDP(model, process_group=pg) |
|
|
|
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()): |
|
|
|
|
torch_p.data.copy_(p) |
|
|
|
|
|
|
|
|
|
init_spec_func(model, pg) |
|
|
|
|
check_param_equal(model, torch_model, pg) |
|
|
|
|
model.train() |
|
|
|
|
torch_model.train() |
|
|
|
|
set_seed(pg.tp_local_rank()) |
|
|
|
|
torch.distributed.barrier() |
|
|
|
|
|
|
|
|
|
for i, (input_ids, attn_mask) in enumerate(train_dataloader): |
|
|
|
|
logits = model(input_ids, attn_mask) |
|
|
|
|
colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) |
|
|
|
|
logits = model(colo_input, attn_mask) |
|
|
|
|
torch_logits = torch_model(input_ids, attn_mask) |
|
|
|
|
assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}" |
|
|
|
|
loss = criterion(logits, input_ids) |
|
|
|
@ -95,14 +100,13 @@ def run_dist(rank, world_size, port, use_ddp):
|
|
|
|
|
tp_world_size = world_size // 2 if use_ddp else world_size |
|
|
|
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) |
|
|
|
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') |
|
|
|
|
# run_gpt(init_1d_row_spec, use_ddp) |
|
|
|
|
run_gpt(init_1d_row_spec, use_ddp) |
|
|
|
|
run_gpt(init_1d_col_spec, use_ddp) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist |
|
|
|
|
@pytest.mark.skip("under development") |
|
|
|
|
@pytest.mark.parametrize('world_size', [1, 4]) |
|
|
|
|
@pytest.mark.parametrize('use_ddp', [False, True]) |
|
|
|
|
@pytest.mark.parametrize('use_ddp', [False]) |
|
|
|
|
@rerun_if_address_is_in_use() |
|
|
|
|
def test_gpt(world_size, use_ddp): |
|
|
|
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) |
|
|
|
@ -110,4 +114,4 @@ def test_gpt(world_size, use_ddp):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
test_gpt(4, True) |
|
|
|
|
test_gpt(4, False) |
|
|
|
|