mirror of https://github.com/hpcaitech/ColossalAI
[test] optimized zero data parallel test (#452)
parent cfcc8271f3
commit f27d801a13
@@ -16,47 +16,29 @@ from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18

 # Config
-BATCH_SIZE = 16
-IMG_SIZE = 224
+BATCH_SIZE = 2
 NUM_CLASSES = 10

-CONFIG = dict(
-    parallel=dict(
-        pipeline=dict(size=1),
-        tensor=dict(size=1, mode=None)
-    ),
-    clip_grad_norm=1.0,
-    gradient_accumulation=4
-)
+CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
+              clip_grad_norm=1.0,
+              gradient_accumulation=4)


 def run_no_pipeline(rank, world_size, port):

     # init dist env
-    colossalai.launch(
-        config=CONFIG,
-        rank=rank,
-        world_size=world_size,
-        host='localhost',
-        port=port,
-        backend='nccl'
-    )
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model
     model = resnet18(num_classes=10)

     # build dataloaders
-    train_dataset = CIFAR10(
-        root=Path(os.environ['DATA']),
-        download=True,
-        transform=transforms.Compose(
-            [
-                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                transforms.ToTensor(),
-                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-            ]
-        )
-    )
+    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
+                            download=True,
+                            transform=transforms.Compose([
+                                transforms.ToTensor(),
+                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+                            ]))
     train_dataloader = get_dataloader(dataset=train_dataset,
                                       shuffle=True,
                                       batch_size=BATCH_SIZE,
@@ -67,12 +49,10 @@ def run_no_pipeline(rank, world_size, port):
     optimizer = Adam(model.parameters(), lr=0.001)
     criterion = nn.CrossEntropyLoss()

-    engine, train_dataloader, *args = colossalai.initialize(
-        model=model,
-        optimizer=optimizer,
-        criterion=criterion,
-        train_dataloader=train_dataloader
-    )
+    engine, train_dataloader, *args = colossalai.initialize(model=model,
+                                                            optimizer=optimizer,
+                                                            criterion=criterion,
+                                                            train_dataloader=train_dataloader)
     logger = get_dist_logger()
     rank = torch.distributed.get_rank()
     param_track = []
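The two hunks above only touch the test's setup: the batch shrinks from 16 to 2 and the 224x224 resize is dropped so the run is cheaper, while the multi-line CONFIG, colossalai.launch and colossalai.initialize calls are collapsed into a compact keyword-argument style. The config still requests gradient accumulation over 4 micro-batches with a gradient-norm clip of 1.0. As orientation only, the following is a minimal, self-contained sketch in plain PyTorch of what gradient_accumulation=4 plus clip_grad_norm=1.0 amount to; it is illustrative and is not ColossalAI's engine implementation (the tiny linear model and random data are stand-ins).

import torch
import torch.nn as nn

ACCUMULATION_STEPS = 4    # mirrors gradient_accumulation=4 in CONFIG
MAX_NORM = 1.0            # mirrors clip_grad_norm=1.0 in CONFIG

model = nn.Linear(8, 2)                                    # stand-in for resnet18
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

optimizer.zero_grad()
for step in range(8):                                      # stand-in for iterating train_dataloader
    data = torch.randn(2, 8)                               # micro-batch of BATCH_SIZE = 2
    label = torch.randint(0, 2, (2,))
    loss = criterion(model(data), label) / ACCUMULATION_STEPS    # scale so accumulated grads match one large batch
    loss.backward()                                        # gradients accumulate in .grad across micro-batches
    if (step + 1) % ACCUMULATION_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)
        optimizer.step()                                   # one optimizer update every 4 micro-batches
        optimizer.zero_grad()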
@@ -17,35 +17,7 @@ from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
 from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
+from colossalai.testing import parameterize


-class Enumerator:
-
-    def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
-        self.arg_names = arg_names
-        self.enums = Enumerator.all_enumerate(arg_values)
-
-    def __len__(self):
-        return len(self.enums)
-
-    def __getitem__(self, idx):
-        return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}
-
-    @staticmethod
-    def all_enumerate(args: List[tuple]):
-        num_states = reduce(op.mul, map(lambda xs: len(xs), args))
-        idxs = [0] * len(args)
-        states = []
-        for _ in range(num_states):
-            states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
-            if len(states) == num_states:
-                break
-            i = 0
-            while idxs[i] + 1 == len(args[i]):
-                idxs[i] = 0
-                i += 1
-            idxs[i] += 1
-        return states
-
-
 def checkpoint_wrapper(module, enable=True):
@@ -125,6 +97,10 @@ def check_params(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)


+@parameterize('checkpoint', [False, True])
+@parameterize('fp16', [False, True])
+@parameterize('offload', [False, True])
+@parameterize('norm_type', [1.0, 2.0, float('inf')])
 def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
     model = Net(checkpoint=checkpoint).cuda()
     zero_model = copy.deepcopy(model)
@@ -155,15 +131,6 @@ def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
 def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    args = ['checkpoint', 'fp16', 'offload', 'norm_type']
-    arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
-    arg_enumerator = Enumerator(args, arg_values)
-
-    for kwargs in arg_enumerator:
-        if dist.get_rank() == 0:
-            print(kwargs)
-        check_config(**kwargs)
     check_config()
-

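The change above removes the hand-rolled Enumerator helper, which walked the Cartesian product of argument values with an odometer-style index, and replaces it with stacked @parameterize decorators from colossalai.testing, so check_config() can now be called with no arguments inside the spawned process. As an illustration of the idea only, not colossalai.testing's actual implementation, a decorator with that behaviour can be sketched in a few lines:

import functools


def parameterize_sketch(arg_name, values):
    """Hypothetical stand-in for colossalai.testing.parameterize."""

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                # run the wrapped function once per value; stacking several of
                # these decorators yields the full Cartesian product, which is
                # exactly what the deleted Enumerator class computed by hand
                func(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator


@parameterize_sketch('checkpoint', [False, True])
@parameterize_sketch('norm_type', [1.0, 2.0, float('inf')])
def check_config_demo(checkpoint, norm_type):
    print(f'checkpoint={checkpoint}, norm_type={norm_type}')


check_config_demo()    # runs all 2 x 3 combinations

Keyword arguments that no decorator claims (for example the logger that a later file in this diff passes as run_model_test(logger=logger)) simply flow through the wrapper unchanged.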
@@ -15,11 +15,12 @@ from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.testing import parameterize


-def run_dist(rank, world_size, port, init_device, shard_strategy):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
+@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_model_test(init_device, shard_strategy):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
@@ -43,19 +44,18 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)


+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_model_test()
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
-@pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_zero_init_context(world_size, init_device, shard_strategy):
-    run_func = partial(run_dist,
-                       world_size=world_size,
-                       port=free_port(),
-                       init_device=init_device,
-                       shard_strategy=shard_strategy)
+def test_zero_init_context(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
     # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
-    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
+    test_zero_init_context(4)
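This file establishes the driver pattern the rest of the commit repeats: pytest parametrizes only world_size, each spawned rank calls colossalai.launch once, and the @parameterize-decorated run_model_test() enumerates the init_device and shard_strategy combinations inside that process. The practical effect is that mp.spawn is invoked once per world size instead of once per (world_size, init_device, shard_strategy) case. A small, self-contained sketch of the driver half under those assumptions (the worker body is a placeholder, not the real launch/test code):

from functools import partial

import torch.multiprocessing as mp


def run_dist_sketch(rank, world_size, port):
    # the real tests call colossalai.launch(...) here and then the
    # @parameterize-decorated helper with no extra arguments
    print(f'rank {rank}/{world_size} (port {port}): run every parameter combination here')


def drive(world_size, port=29500):    # the port value is an arbitrary placeholder
    run_func = partial(run_dist_sketch, world_size=world_size, port=port)
    mp.spawn(run_func, nprocs=world_size)    # spawn calls run_func(rank) once per rank


if __name__ == '__main__':
    drive(world_size=2)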
@@ -20,12 +20,13 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

 from common import CONFIG, check_grads_padding, run_fwd_bwd
+from colossalai.testing import parameterize


-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    logger = get_dist_logger()
-    logger.set_level('DEBUG')
+@parameterize("enable_autocast", [True])
+@parameterize("use_zero_init_ctx", [True])
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy()
     for model_name in test_models:
@@ -66,20 +67,19 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
         # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)


+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    logger = get_dist_logger()
+    logger.set_level('DEBUG')
+    run_model_test(logger=logger)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("enable_autocast", [True])
-@pytest.mark.parametrize("use_zero_init_ctx", [True])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
-    run_func = partial(run_dist,
-                       world_size=world_size,
-                       port=free_port(),
-                       use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast,
-                       shard_strategy=shard_strategy)
+def test_shard_model_v2(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
+    test_shard_model_v2(world_size=2)
@@ -15,10 +15,11 @@ from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_zero_data_parallel.common import CONFIG, allclose
+from colossalai.testing import parameterize


-def _run_shard_tensor(rank, world_size, port, shard_strategy):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_shard_tensor_with_strategy(shard_strategy, world_size):
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]
@@ -32,11 +33,15 @@ def _run_shard_tensor(rank, world_size, port, shard_strategy):
     assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"


+def _run_shard_tensor(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_shard_tensor_with_strategy(world_size=world_size)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_shard_tensor(world_size, shard_strategy):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
+def test_shard_tensor(world_size):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


@@ -122,7 +127,7 @@ def test_init_shard_param(world_size):


 if __name__ == '__main__':
-    test_shard_tensor(2, TensorShardStrategy)
+    test_shard_tensor(2)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)
@@ -14,7 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import parameterize
 from common import CONFIG, check_sharded_params_padding


@@ -36,8 +36,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()


-def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize("cpu_offload", [True, False])
+@parameterize("use_cpuadam", [True, False])
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = shard_strategy()

@@ -76,36 +78,18 @@ def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
             assert not has_inf_or_nan(param)


+def _run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_test_sharded_optim_v2()
+
+
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("cpu_offload", [False])
-@pytest.mark.parametrize("use_cpuadam", [False])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy, use_cpuadam):
-    run_func = partial(_run_dist,
-                       world_size=world_size,
-                       port=free_port(),
-                       cpu_offload=cpu_offload,
-                       shard_strategy=shard_strategy,
-                       use_cpuadam=use_cpuadam)
-    mp.spawn(run_func, nprocs=world_size)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("cpu_offload", [True])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-@pytest.mark.parametrize("use_cpuadam", [True, False])
-def test_sharded_optim_v2_cpu_adam(world_size, cpu_offload, shard_strategy, use_cpuadam):
-    run_func = partial(_run_dist,
-                       world_size=world_size,
-                       port=free_port(),
-                       cpu_offload=cpu_offload,
-                       shard_strategy=shard_strategy,
-                       use_cpuadam=use_cpuadam)
+def test_sharded_optim_v2(world_size):
+    run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2_cpu_adam(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy, use_cpuadam=True)
+    test_sharded_optim_v2(world_size=2)
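Note that the two previous pytest entry points, test_sharded_optim_v2 (cpu_offload=False, use_cpuadam=False) and test_sharded_optim_v2_cpu_adam (cpu_offload=True, use_cpuadam in [True, False]), collapse into a single test whose stacked @parameterize decorators enumerate cpu_offload, use_cpuadam and the shard strategy inside each rank. That raises coverage from the 6 decorator combinations per world size that pytest used to generate to all 8, including cpu_offload=False with use_cpuadam=True, the pairing the retained comment says is allowed. A quick, purely illustrative bit of bookkeeping (not part of the test):

from itertools import product

combos = list(product([True, False],          # cpu_offload
                      [True, False],          # use_cpuadam
                      ['TensorShardStrategy', 'BucketTensorShardStrategy']))
print(len(combos))    # 8 combinations exercised per spawned rank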
@@ -12,12 +12,12 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import parameterize
 from common import CONFIG


-def run_dist(rank, world_size, port, shard_strategy):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_zero_state_dict(shard_strategy):
     test_models = ['repeated_computed_layers', 'resnet18']
     shard_strategy = shard_strategy()
     for model_name in test_models:
@@ -31,11 +31,15 @@ def run_dist(rank, world_size, port, shard_strategy):
             assert torch.equal(val, zero_state_dict[key])


+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_zero_state_dict()
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
-def test_zero_state_dict(world_size, shard_strategy):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
+def test_zero_state_dict(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)

