ColossalAI/tests/test_zero_data_parallel/test_sharded_optim.py

import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.zero import ShardedOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.utils import free_port
from functools import partial
from common import allclose
from tests.components_to_test.registry import non_distributed_component_funcs


def check_completely_equal(a, b):
    """
    This function checks if two tensors are completely equal
    """
    assert torch.all(a == b), f'a = {a}, b = {b}'


def check_sharded_param_consistency():
    """
    In this test, we want to check whether ZeRO stage 1 and stage 2
    deliver the same numerical results despite their different
    communication patterns.

    We use the following prefixes to differentiate the two stages:
    oss: partitions optimizer states only
    pg: partitions both gradients and optimizer states
    """
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, *_ = get_components_func()

        # create model
        oss_model = model_builder(checkpoint=True).cuda().half()
        pg_model = copy.deepcopy(oss_model)

        # create optimizer
        oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
        pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
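
        # wrap both optimizers with ShardedOptimizer:
        #   oss partitions only the optimizer states (stage 1)
        #   pg additionally partitions gradients via partition_grad=True (stage 2)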
        oss_optimizer = ShardedOptimizer(oss_optimizer,
                                         overlap_communication=True,
                                         initial_scale=1,
                                         clip_grad_norm=0.0)
        pg_optimizer = ShardedOptimizer(pg_optimizer,
                                        overlap_communication=True,
                                        partition_grad=True,
                                        initial_scale=1,
                                        clip_grad_norm=0.0)

        # create data
        data, label = next(iter(train_dataloader))
        input_data = data.cuda().half()
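        # the input is cast to fp16 to match the half-precision models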

        # forward
        oss_output = oss_model(input_data)
        pg_output = pg_model(input_data)
        check_completely_equal(oss_output, pg_output)

        # backward
        oss_optimizer.backward(oss_output.mean().float())
        pg_optimizer.backward(pg_output.mean().float())

        # check grad
        # as the params are small, the backward reduction
        # will not be fired
        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
            check_completely_equal(oss_param.grad, pg_param.grad)

        # sync grad across ranks
        oss_optimizer.sync_grad()
        pg_optimizer.sync_grad()

        # step
        oss_optimizer.step()
        pg_optimizer.step()

        # check updated param
        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
            check_completely_equal(oss_param, pg_param)


def check_sharded_optim_against_torch_ddp():
    """
    In this test, two pairs of models and optimizers are created:
    1. zero: uses the sharded optimizer and fp16 parameters
    2. torch: uses torch DDP and fp32 parameters

    We feed both models the same input and check that the differences
    in model outputs and updated parameters stay within tolerance.
    """
    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']

    for name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(name)
        model_builder, train_dataloader, *_ = get_components_func()

        # create model
        zero_model = model_builder(checkpoint=True).cuda()
        torch_model = copy.deepcopy(zero_model)
        zero_model = zero_model.half()
        torch_model = DDP(torch_model.cuda())
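
        # the ZeRO copy runs in fp16 while the DDP baseline keeps fp32 parameters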

        # create optimizer
        zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

        # we only test stage 1 here
        # in `check_sharded_param_consistency`, we test whether
        # stage 1 and stage 2 produce exactly the same results
        zero_optimizer = ShardedOptimizer(zero_optimizer,
                                          overlap_communication=True,
                                          initial_scale=1,
                                          clip_grad_norm=0.0)
        torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

        # create data
        input_data, _ = next(iter(train_dataloader))
        input_data = input_data.cuda()
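        # the ZeRO model consumes an fp16 copy of the input; the DDP baseline uses fp32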

        # zero-dp forward
        zero_output = zero_model(input_data.half())

        # torch-ddp forward
        torch_output = torch_model(input_data)
        allclose(zero_output, torch_output.half())
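        # outputs are compared with a tolerance (allclose) rather than exact equality,
        # since the fp16 and fp32 models cannot be expected to match bit-for-bit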

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean().float())

        # torch-ddp backward
        torch_output.mean().backward()

        # check grad
        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
            allclose(oss_param.grad, torch_param.grad.half())
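
        # before stepping, the sharded optimizer synchronizes gradients across ranks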
        # zero-dp step
        zero_optimizer.sync_grad()
        zero_optimizer.step()

        # torch-ddp step
        torch_optimizer.step()

        # check updated param
        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
            allclose(oss_param, torch_param.half())


def run_dist(rank, world_size, port):
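    # initialize the distributed environment for this rank, then run both checks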
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_sharded_optim_against_torch_ddp()
    check_sharded_param_consistency()


@pytest.mark.dist
def test_sharded_optim():
    world_size = 2
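    # spawn world_size processes, each running run_dist with the same free port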
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim()