From a0e59716920b4225fe0c226a9618c46f8ff25f0d Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 27 Apr 2022 12:00:18 +0800
Subject: [PATCH] [Tensor] test model check results for a simple net (#887)

---
 .../{test_net_tp.py => test_model.py}    | 54 +++++++++++++++----
 1 file changed, 44 insertions(+), 10 deletions(-)
 rename tests/test_tensor/{test_net_tp.py => test_model.py} (62%)

diff --git a/tests/test_tensor/test_net_tp.py b/tests/test_tensor/test_model.py
similarity index 62%
rename from tests/test_tensor/test_net_tp.py
rename to tests/test_tensor/test_model.py
index b07b1fd18..cb885b152 100644
--- a/tests/test_tensor/test_net_tp.py
+++ b/tests/test_tensor/test_model.py
@@ -2,6 +2,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 
 import colossalai
 import pytest
+import torch
 import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
@@ -9,15 +10,30 @@ from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
 from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
 from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
 from functools import partial
+import random
+import os
+import numpy as np
 
 
-def run_simple_net():
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def run_1d_row_tp():
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
 
+    set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
@@ -26,6 +42,11 @@ def run_simple_net():
     ]
     spec = TensorSpec(parallel_action_list)
 
+    set_seed(1)
+    if rank == 0:
+        model_torch = model_builder(checkpoint=True)
+        model_torch = model_torch.cuda()
+
     # A naive way to set spec for all weights in Linear
     for name, p in named_params_with_colotensor(model):
         if not isinstance(p, ColoTensor):
@@ -33,15 +54,16 @@ def run_simple_net():
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
             p.set_spec(spec)
 
-    model.cuda()
-
-    for param in named_params_with_colotensor(model):
-        print(param)
+    model = model.cuda()
 
     for i, (data, label) in enumerate(train_dataloader):
         data = data.to(get_current_device())
         label = label.to(get_current_device())
 
+        # Broadcast rank 0's batch so every rank consumes identical data
+        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
         if criterion:
             output = model(data)
             loss = criterion(output, label)
@@ -49,22 +71,34 @@ def run_simple_net():
             output = model(data, label)
             loss = output
 
-        print(loss.torch_tensor())
+        # Compute the reference loss on the unsharded torch model
+        if rank == 0:
+            if criterion:
+                output_torch = model_torch(data)
+                loss_torch = criterion(output_torch, label)
+            else:
+                output_torch = model_torch(data, label)
+                loss_torch = output_torch
+
+        if rank == 0:
+            # The sharded loss must match the dense reference; rtol absorbs
+            # reduction-order differences in the parallel matmuls
+            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
+
         loss.backward()
+        if rank == 0:
+            loss_torch.backward()
 
         if i > 5:
             break
 
-    # TODO(jzy) check the results with col.nn.Linear?
-
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_simple_net()
+    run_1d_row_tp()
 
 
-@pytest.mark.skip
 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
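
The check this patch introduces follows one pattern: seed every RNG source, build the sharded model and a plain torch reference from the same seed, feed both the same batch, and assert that the losses agree within tolerance. Below is a minimal, framework-free sketch of that pattern; SimpleNet, the tensor shapes, and the batch sizes are illustrative stand-ins, not code from the repository. In the real test, the first build happens under ColoInitContext, its Linear weights carry the TP1DRow spec, and rank 0's batch is broadcast to all ranks.

import random

import numpy as np
import torch
import torch.nn as nn


def set_seed(seed):
    # Seed every RNG source, as the patch does, so two successive
    # model builds draw bit-identical initial weights.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


class SimpleNet(nn.Module):
    # Hypothetical stand-in for the registry's 'simple_net' component:
    # two stacked nn.Linear layers.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def main():
    criterion = nn.CrossEntropyLoss()

    set_seed(1)
    model = SimpleNet()          # stands in for the ColoTensor-sharded model
    set_seed(1)
    model_torch = SimpleNet()    # stands in for the rank-0 dense reference

    data = torch.randn(4, 8)           # one shared batch (the patch
    label = torch.randint(0, 2, (4,))  # broadcasts rank 0's batch instead)

    loss = criterion(model(data), label)
    loss_torch = criterion(model_torch(data), label)

    # In this sketch the two models are identical, so the losses match
    # exactly; the rtol=1e-2 tolerance matters in the real test, where
    # sharded matmuls reduce partial sums in a different order than the
    # dense reference does.
    assert torch.allclose(loss, loss_torch, rtol=1e-2)

    loss.backward()
    loss_torch.backward()
    print('losses agree:', loss.item(), loss_torch.item())


if __name__ == '__main__':
    main()

In the patch itself the left-hand model is the ColoTensor model whose Linear weights were given the TP1DRow TensorSpec, so the allclose assertion is what actually exercises the row-parallel kernels.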