[unitest] polish zero config in unittest (#438)

pull/442/head
Jiarui Fang 2022-03-17 10:20:53 +08:00 committed by GitHub
parent 640a6cd304
commit 17b8274f8a
3 changed files with 24 additions and 25 deletions


@@ -10,6 +10,18 @@ from colossalai.zero.sharded_model import ShardedModelV2
 LOGGER = get_dist_logger()
 
+_ZERO_OPTIMIZER_CONFIG = dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))
+_ZERO_OFFLOAD_OPTIMIZER_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False)
+_ZERO_OFFLOAD_PARAM_CONFIG = dict(device='cpu', pin_memory=True, buffer_count=5, buffer_size=1e8, max_in_cpu=1e9)
+
+ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
+                            zero=dict(
+                                optimzer=_ZERO_OPTIMIZER_CONFIG,
+                                offload_optimizer_config=_ZERO_OFFLOAD_OPTIMIZER_CONFIG,
+                                offload_param_config=_ZERO_OFFLOAD_PARAM_CONFIG,
+                            ),
+                            parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
+
 CONFIG = dict(fp16=dict(mode=None,),
               zero=dict(level=3,
                         verbose=False,
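For reference, ZERO_PARALLEL_CONFIG above expands to the same nested dict that the test further below previously built inline. A sketch of the fully expanded form, with every value copied from the _ZERO_* helpers above (the expanded name is illustrative and not part of the diff):

import torch

# Fully expanded equivalent of ZERO_PARALLEL_CONFIG (sketch only; all values are
# copied from _ZERO_OPTIMIZER_CONFIG, _ZERO_OFFLOAD_OPTIMIZER_CONFIG and
# _ZERO_OFFLOAD_PARAM_CONFIG defined in the hunk above).
ZERO_PARALLEL_CONFIG_EXPANDED = dict(
    fp16=dict(mode=None,),
    zero=dict(
        optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
        offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False),
        offload_param_config=dict(device='cpu', pin_memory=True, buffer_count=5,
                                  buffer_size=1e8, max_in_cpu=1e9),
    ),
    parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
)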


@@ -19,10 +19,8 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    colossalai.launch(config=dict(
-        zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3))),
-        cudnn_determinstic=True,
-        cudnn_benchmark=False),
+    zero_config = dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)))
+    colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
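The refactored launch call above still passes the same keys, only built from a local zero_config first. Functions like run_dist are normally driven once per rank by torch.multiprocessing.spawn; a minimal sketch of that harness, assuming a two-process test (the test name and world size are assumptions, not shown in this diff):

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port


def test_shard_model():  # hypothetical test name, not taken from the diff
    world_size = 2  # assumed world size; the diff does not show the real value
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    # spawn passes the rank as the first positional argument to run_dist
    mp.spawn(run_func, nprocs=world_size)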


@@ -8,32 +8,21 @@ import pytest
 
 import colossalai
 from colossalai.utils import free_port
-import torch
 import torch.multiprocessing as mp
 from tests.components_to_test.registry import non_distributed_component_funcs
-from common import check_sharded_params_padding
+from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG
 
 
 def run_dist(rank, world_size, port):
-    _config = dict(fp16=dict(mode=None,),
-                   zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
-                             offload_optimizer_config=dict(device='cpu',
-                                                           pin_memory=True,
-                                                           buffer_count=5,
-                                                           fast_init=False),
-                             offload_param_config=dict(device='cpu',
-                                                       pin_memory=True,
-                                                       buffer_count=5,
-                                                       buffer_size=1e8,
-                                                       max_in_cpu=1e9)),
-                   parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
-    colossalai.launch(config=_config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    # FIXME revert back
-    # test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    test_models = ['bert']
+    colossalai.launch(config=ZERO_PARALLEL_CONFIG,
+                      rank=rank,
+                      world_size=world_size,
+                      host='localhost',
+                      port=port,
+                      backend='nccl')
+
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
@@ -65,8 +54,8 @@ def run_dist(rank, world_size, port):
                 output = engine(data)
                 loss = engine.criterion(output, label)
-                torch_model(data, label)
-                torch_loss = engine.criterion(output, label)
+                torch_output = torch_model(data)
+                torch_loss = engine.criterion(torch_output, label)
             else:
                 loss = engine(data, label)
                 torch_loss = torch_model(data, label)
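The hunk above fixes the reference path of the test: before this change the torch model's output was discarded and torch_loss was recomputed from the engine's output, so the reference loss could never disagree with the sharded one. A short sketch of the corrected step with that reasoning spelled out in comments (variable names are taken from the diff; the surrounding loop and branch are assumed):

output = engine(data)                               # forward pass through the sharded (ZeRO) engine
loss = engine.criterion(output, label)

torch_output = torch_model(data)                    # forward pass through the unsharded reference model
torch_loss = engine.criterion(torch_output, label)  # loss computed from the reference output,
                                                    # so it can meaningfully be compared with `loss`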