mirror of https://github.com/hpcaitech/ColossalAI
[test] optimized zero data parallel test (#452)
parent
cfcc8271f3
commit
f27d801a13
|
@ -16,47 +16,29 @@ from torchvision.datasets import CIFAR10
|
|||
from torchvision.models import resnet18
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 16
|
||||
IMG_SIZE = 224
|
||||
BATCH_SIZE = 2
|
||||
NUM_CLASSES = 10
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
),
|
||||
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
|
||||
clip_grad_norm=1.0,
|
||||
gradient_accumulation=4
|
||||
)
|
||||
gradient_accumulation=4)
|
||||
|
||||
|
||||
def run_no_pipeline(rank, world_size, port):
|
||||
|
||||
# init dist env
|
||||
colossalai.launch(
|
||||
config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl'
|
||||
)
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
||||
]
|
||||
)
|
||||
)
|
||||
]))
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
|
@ -67,12 +49,10 @@ def run_no_pipeline(rank, world_size, port):
|
|||
optimizer = Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
engine, train_dataloader, *args = colossalai.initialize(
|
||||
model=model,
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader
|
||||
)
|
||||
train_dataloader=train_dataloader)
|
||||
logger = get_dist_logger()
|
||||
rank = torch.distributed.get_rank()
|
||||
param_track = []
|
||||
|
|
|
@ -17,35 +17,7 @@ from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
|
|||
from colossalai.zero.sharded_model import ShardedModel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
|
||||
class Enumerator:
|
||||
|
||||
def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
|
||||
self.arg_names = arg_names
|
||||
self.enums = Enumerator.all_enumerate(arg_values)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.enums)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}
|
||||
|
||||
@staticmethod
|
||||
def all_enumerate(args: List[tuple]):
|
||||
num_states = reduce(op.mul, map(lambda xs: len(xs), args))
|
||||
idxs = [0] * len(args)
|
||||
states = []
|
||||
for _ in range(num_states):
|
||||
states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
|
||||
if len(states) == num_states:
|
||||
break
|
||||
i = 0
|
||||
while idxs[i] + 1 == len(args[i]):
|
||||
idxs[i] = 0
|
||||
i += 1
|
||||
idxs[i] += 1
|
||||
return states
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
|
@ -125,6 +97,10 @@ def check_params(model, zero_model, loose=False):
|
|||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
|
||||
@parameterize('checkpoint', [False, True])
|
||||
@parameterize('fp16', [False, True])
|
||||
@parameterize('offload', [False, True])
|
||||
@parameterize('norm_type', [1.0, 2.0, float('inf')])
|
||||
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
|
||||
model = Net(checkpoint=checkpoint).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
|
@ -155,15 +131,6 @@ def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
|
|||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
args = ['checkpoint', 'fp16', 'offload', 'norm_type']
|
||||
arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
|
||||
arg_enumerator = Enumerator(args, arg_values)
|
||||
|
||||
for kwargs in arg_enumerator:
|
||||
if dist.get_rank() == 0:
|
||||
print(kwargs)
|
||||
check_config(**kwargs)
|
||||
check_config()
|
||||
|
||||
|
||||
|
|
|
@ -15,11 +15,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
|||
|
||||
from common import CONFIG
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, init_device, shard_strategy):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(init_device, shard_strategy):
|
||||
for get_components_func in non_distributed_component_funcs:
|
||||
model_builder, _, _, _, _ = get_components_func()
|
||||
model_numel_tensor = torch.zeros(1, dtype=torch.int)
|
||||
|
@ -43,19 +44,18 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
|
|||
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_model_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 4])
|
||||
@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
|
||||
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def test_zero_init_context(world_size, init_device, shard_strategy):
|
||||
run_func = partial(run_dist,
|
||||
world_size=world_size,
|
||||
port=free_port(),
|
||||
init_device=init_device,
|
||||
shard_strategy=shard_strategy)
|
||||
def test_zero_init_context(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
|
||||
test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
|
||||
test_zero_init_context(4)
|
||||
|
|
|
@ -20,12 +20,13 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from common import CONFIG, check_grads_padding, run_fwd_bwd
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
logger = get_dist_logger()
|
||||
logger.set_level('DEBUG')
|
||||
@parameterize("enable_autocast", [True])
|
||||
@parameterize("use_zero_init_ctx", [True])
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||
shard_strategy = shard_strategy()
|
||||
for model_name in test_models:
|
||||
|
@ -66,20 +67,19 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
|
|||
# logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
logger = get_dist_logger()
|
||||
logger.set_level('DEBUG')
|
||||
run_model_test(logger=logger)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@pytest.mark.parametrize("enable_autocast", [True])
|
||||
@pytest.mark.parametrize("use_zero_init_ctx", [True])
|
||||
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
|
||||
run_func = partial(run_dist,
|
||||
world_size=world_size,
|
||||
port=free_port(),
|
||||
use_zero_init_ctx=use_zero_init_ctx,
|
||||
enable_autocast=enable_autocast,
|
||||
shard_strategy=shard_strategy)
|
||||
def test_shard_model_v2(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
|
||||
test_shard_model_v2(world_size=2)
|
||||
|
|
|
@ -15,10 +15,11 @@ from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
|
|||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_zero_data_parallel.common import CONFIG, allclose
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
def _run_shard_tensor(rank, world_size, port, shard_strategy):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_shard_tensor_with_strategy(shard_strategy, world_size):
|
||||
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
|
||||
assert list(t.origin_shape) == [world_size * 2, 3]
|
||||
assert list(t.shape) == [world_size * 2, 3]
|
||||
|
@ -32,11 +33,15 @@ def _run_shard_tensor(rank, world_size, port, shard_strategy):
|
|||
assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"
|
||||
|
||||
|
||||
def _run_shard_tensor(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_shard_tensor_with_strategy(world_size=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def test_shard_tensor(world_size, shard_strategy):
|
||||
run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
|
||||
def test_shard_tensor(world_size):
|
||||
run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
@ -122,7 +127,7 @@ def test_init_shard_param(world_size):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shard_tensor(2, TensorShardStrategy)
|
||||
test_shard_tensor(2)
|
||||
test_shard_param(2)
|
||||
test_shard_param_v2(2)
|
||||
test_init_shard_param(4)
|
||||
|
|
|
@ -14,7 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.optimizer import CPUAdam
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
|
||||
from colossalai.testing import parameterize
|
||||
from common import CONFIG, check_sharded_params_padding
|
||||
|
||||
|
||||
|
@ -36,8 +36,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
|
|||
optimizer.step()
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@parameterize("cpu_offload", [True, False])
|
||||
@parameterize("use_cpuadam", [True, False])
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||
shard_strategy = shard_strategy()
|
||||
|
||||
|
@ -76,36 +78,18 @@ def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
|
|||
assert not has_inf_or_nan(param)
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
_run_test_sharded_optim_v2()
|
||||
|
||||
|
||||
# use_cpuadam = True can be used with cpu_offload = False
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@pytest.mark.parametrize("cpu_offload", [False])
|
||||
@pytest.mark.parametrize("use_cpuadam", [False])
|
||||
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy, use_cpuadam):
|
||||
run_func = partial(_run_dist,
|
||||
world_size=world_size,
|
||||
port=free_port(),
|
||||
cpu_offload=cpu_offload,
|
||||
shard_strategy=shard_strategy,
|
||||
use_cpuadam=use_cpuadam)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@pytest.mark.parametrize("cpu_offload", [True])
|
||||
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
@pytest.mark.parametrize("use_cpuadam", [True, False])
|
||||
def test_sharded_optim_v2_cpu_adam(world_size, cpu_offload, shard_strategy, use_cpuadam):
|
||||
run_func = partial(_run_dist,
|
||||
world_size=world_size,
|
||||
port=free_port(),
|
||||
cpu_offload=cpu_offload,
|
||||
shard_strategy=shard_strategy,
|
||||
use_cpuadam=use_cpuadam)
|
||||
def test_sharded_optim_v2(world_size):
|
||||
run_func = partial(_run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sharded_optim_v2_cpu_adam(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy, use_cpuadam=True)
|
||||
test_sharded_optim_v2(world_size=2)
|
||||
|
|
|
@ -12,12 +12,12 @@ from colossalai.utils import free_port
|
|||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from colossalai.testing import parameterize
|
||||
from common import CONFIG
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, shard_strategy):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_zero_state_dict(shard_strategy):
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
shard_strategy = shard_strategy()
|
||||
for model_name in test_models:
|
||||
|
@ -31,11 +31,15 @@ def run_dist(rank, world_size, port, shard_strategy):
|
|||
assert torch.equal(val, zero_state_dict[key])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_zero_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def test_zero_state_dict(world_size, shard_strategy):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
|
||||
def test_zero_state_dict(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue