[Tensor] test model check results for a simple net (#887)

pull/888/head
Jiarui Fang 2022-04-27 12:00:18 +08:00 committed by GitHub
parent 72cdc06875
commit a0e5971692
1 changed file with 44 additions and 10 deletions


@@ -2,6 +2,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
@@ -9,15 +10,30 @@ from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
 from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
 from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
 from functools import partial
+import random
+import os
+import numpy as np


-def run_simple_net():
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+def run_1d_row_tp():
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    set_seed(1)

     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
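
The set_seed helper above is what makes the later rank-0 comparison meaningful: re-seeding before each model_builder call gives both model copies identical initial weights. A minimal standalone sketch of that idea in plain PyTorch (the build helper and the Linear shape are hypothetical stand-ins, not from this test):

    import torch
    import torch.nn as nn

    def build():
        torch.manual_seed(1)      # same seed before every build
        return nn.Linear(4, 4)    # hypothetical stand-in for model_builder

    a, b = build(), build()
    # Both copies start from identical parameters.
    assert torch.equal(a.weight, b.weight) and torch.equal(a.bias, b.bias)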
@@ -26,6 +42,11 @@ def run_simple_net():
     ]
     spec = TensorSpec(parallel_action_list)

+    set_seed(1)
+    if rank == 0:
+        model_torch = model_builder(checkpoint=True)
+        model_torch = model_torch.cuda()
+
     # A naive way to set spec for all weights in Linear
     for name, p in named_params_with_colotensor(model):
         if not isinstance(p, ColoTensor):
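
The rank == 0 block keeps an unsharded, plain-PyTorch copy of the same network on a single device; every other rank skips it, so the reference costs one extra model on one GPU. A hedged sketch of the pattern in isolation (process-group initialization omitted; maybe_build_reference is a hypothetical name, not part of this test):

    import torch.distributed as dist

    def maybe_build_reference(model_builder):
        # Only rank 0 pays for the reference copy; the rest return None.
        if dist.get_rank() == 0:
            return model_builder().cuda()
        return None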
@@ -33,15 +54,16 @@ def run_simple_net():
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
             p.set_spec(spec)

-    model.cuda()
-    for param in named_params_with_colotensor(model):
-        print(param)
+    model = model.cuda()

     for i, (data, label) in enumerate(train_dataloader):
         data = data.to(get_current_device())
         label = label.to(get_current_device())

+        # Bcast rank0 data to all processes
+        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
         if criterion:
             output = model(data)
             loss = criterion(output, label)
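
The two broadcast calls pin every rank to rank 0's batch: torch.distributed.broadcast overwrites the tensor in place on all non-source ranks. A self-contained sketch of that semantics on CPU with the gloo backend (the worker name, port, and world size are arbitrary choices for illustration):

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def worker(rank, world_size):
        dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29501',
                                rank=rank, world_size=world_size)
        data = torch.full((2,), float(rank))  # each rank starts with its own values
        dist.broadcast(data, src=0)           # in place: rank 0's values everywhere
        assert torch.all(data == 0.0)
        dist.destroy_process_group()

    if __name__ == '__main__':
        mp.spawn(worker, args=(2,), nprocs=2)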
@@ -49,22 +71,34 @@ def run_simple_net():
             output = model(data, label)
             loss = output
-        print(loss.torch_tensor())

+        # For reference
+        if rank == 0:
+            if criterion:
+                output_torch = model_torch(data)
+                loss_torch = criterion(output_torch, label)
+            else:
+                output_torch = model_torch(data, label)
+                loss_torch = output_torch
+
+        if rank == 0:
+            # print(loss.torch_tensor().item())
+            # print('loss torch', loss_torch.item())
+            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
+
         loss.backward()
+
+        if rank == 0:
+            loss_torch.backward()
+
         if i > 5:
             break

+    # TODO(jzy) check the results with col.nn.Linear?


 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_simple_net()
+    run_1d_row_tp()


 @pytest.mark.skip
 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
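
The assert via torch.allclose(..., rtol=1e-2) is the model check this commit adds: the tensor-parallel loss must match the single-GPU reference to within roughly 1% relative error, which leaves headroom for reduction-order floating-point differences. A small illustration of that tolerance (the values are made up):

    import torch

    ref = torch.tensor([2.000])
    assert torch.allclose(torch.tensor([2.015]), ref, rtol=1e-2)       # ~0.75% off: passes
    assert not torch.allclose(torch.tensor([2.100]), ref, rtol=1e-2)   # ~5% off: fails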