[test] polish zero related unit tests (#351)

pull/394/head
Jiarui Fang 2022-03-10 09:57:26 +08:00 committed by Frank Lee
parent 534e0bb118
commit cb34cd384d
5 changed files with 75 additions and 123 deletions


@@ -0,0 +1,19 @@
import torch
from colossalai.zero.sharded_model import ShardedModelV2
import copy


def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
    """
    Copy the parameters of a ShardedModelV2 into other_model.
    Note that other_model must have the same architecture as sharded_model.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, 'col_attr')
        shard_flag = zero_param.col_attr.data.is_sharded
        if shard_flag:
            # gather the full tensor before copying its payload
            sharded_model.shard_strategy.gather([zero_param.col_attr.data])
        param.data = copy.deepcopy(zero_param.col_attr.data.payload)
        if shard_flag:
            # restore the original sharded state
            sharded_model.shard_strategy.shard([zero_param.col_attr.data])
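For context, a minimal usage sketch of the new col_model_deepcopy helper, mirroring how the reworked test below builds its fp16 reference model. SimpleNet is a hypothetical stand-in for the registered test models, and the distributed backend is assumed to have been initialized via colossalai.launch beforehand.

import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy


class SimpleNet(nn.Module):
    # hypothetical toy model, only for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


def build_reference_pair():
    # assumes colossalai.launch(...) has already set up the process group
    shard_strategy = TensorShardStrategy()
    # parameters created inside ZeroInitContext carry col_attr and are sharded on the fly
    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
        zero_model = SimpleNet()
    zero_model = ShardedModelV2(zero_model, shard_strategy)

    # plain fp16 replica: gather each shard, deep-copy its payload, then re-shard
    model = SimpleNet().half()
    col_model_deepcopy(zero_model, model)
    return zero_model, model.cuda()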


@@ -3,8 +3,10 @@ from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.sharded_model import ShardedModelV2

LOGGER = get_dist_logger()

@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),
              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
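As a quick illustration, a hedged sketch of the two calling conventions the unified run_fwd_bwd above now covers. ToyNet and ToyNetWithLoss are hypothetical modules, not part of the component registry, and a CUDA device is assumed since the test data lives on GPU.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNet(nn.Module):
    # hypothetical model whose loss is computed by an external criterion
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


class ToyNetWithLoss(nn.Module):
    # hypothetical bert-style model that returns its own loss from (data, label)
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, data, label):
        return F.mse_loss(self.fc(data), label)


data = torch.randn(2, 4, device='cuda')
label = torch.randn(2, 4, device='cuda')

# criterion given -> loss = criterion(model(data), label)
run_fwd_bwd(ToyNet().cuda(), data, label, F.mse_loss)
# criterion omitted -> loss = model(data, label)
run_fwd_bwd(ToyNetWithLoss().cuda(), data, label, criterion=None)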


@@ -3,81 +3,70 @@
import copy
from functools import partial
import pytest
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding
from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(data)
        loss = criterion(y, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


# with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        loss = model(data, label)
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()
def run_dist(rank, world_size, port):
def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = TensorShardStrategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
        model = model(checkpoint=True).half().cuda()
        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
        if dist.get_world_size() > 1:
            model = DDP(model)
        model_builder, train_dataloader, _, _, criterion = get_components_func()
        if use_zero_init_ctx:
            with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
                zero_model = model_builder(checkpoint=True)
            zero_model = ShardedModelV2(zero_model, shard_strategy)
            model = model_builder(checkpoint=True).half()
            col_model_deepcopy(zero_model, model)
            model = model.cuda()
        else:
            model = model_builder(checkpoint=True).half().cuda()
            zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
        model = DDP(model)
        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
            if i > 3:
                break
            if criterion is None:
                data, label = data.cuda(), label.cuda()
                run_fwd_bwd_no_criterion(model, data, label, False)
                run_fwd_bwd_no_criterion(zero_model, data, label, False)
            else:
                data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
                run_fwd_bwd(model, data, label, criterion, False)
                run_fwd_bwd(zero_model, data, label, criterion, False)
            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, enable_autocast)
            run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
            check_grads_padding(model, zero_model, loose=True)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_shard_model_v2(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("enable_autocast", [True])
@pytest.mark.parametrize("use_zero_init_ctx", [True])
def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_zero_init_ctx=use_zero_init_ctx,
                       enable_autocast=enable_autocast)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2(world_size=2)
    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
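check_grads_padding itself is not part of this diff; purely as an illustration of the idea behind a padding-aware comparison, here is a hypothetical sketch that flattens a DDP rank's gradient, pads it so it splits evenly across ranks, and compares the local chunk against the sharded model's gradient. The function name, signature, and tolerances are assumptions for illustration, not the actual helper from common.py.

import torch
import torch.distributed as dist


def check_grad_shard(ddp_param: torch.nn.Parameter, zero_grad: torch.Tensor, loose: bool = True) -> bool:
    # hypothetical sketch, not the real check_grads_padding
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # flatten the full DDP gradient and pad it to a multiple of world_size
    grad = ddp_param.grad.detach().flatten().float()
    chunk_size = (grad.numel() + world_size - 1) // world_size
    pad = chunk_size * world_size - grad.numel()
    if pad > 0:
        grad = torch.cat([grad, torch.zeros(pad, dtype=grad.dtype, device=grad.device)])

    # compare this rank's chunk against the sharded model's local gradient
    local_chunk = grad.chunk(world_size)[rank]
    rtol, atol = (1e-3, 1e-1) if loose else (1e-5, 1e-5)
    return torch.allclose(local_chunk, zero_grad.detach().flatten().float(), rtol=rtol, atol=atol)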


@@ -1,73 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads, check_grads_padding
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(data)
        loss = criterion(y, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()
def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_models = ['repeated_computed_layers', 'resnet18']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        shard_strategy = TensorShardStrategy()
        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
            zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
            zero_model = zero_model()
        model = copy.deepcopy(zero_model)
        zero_model = ShardedModelV2(zero_model, shard_strategy)
        model_state_dict = zero_model.state_dict()
        for n, p in model.named_parameters():
            p.data = model_state_dict[n]
        model = model.half().cuda()
        if dist.get_world_size() > 1:
            model = DDP(model)
        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
            data, label = data.half().cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, False)
            run_fwd_bwd(zero_model, data, label, criterion, False)
            if dist.get_world_size() > 1:
                check_grads_padding(model, zero_model, loose=True)
            else:
                check_grads(model, zero_model, loose=True)


@pytest.mark.dist
def test_shard_model_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2()


@@ -78,7 +78,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
@pytest.mark.parametrize("world_size", [1, 2])
def test_sharded_optim_v2(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)