# Source: ColossalAI/tests/test_ddp/test_ddp_state_dict.py
# File-viewer metadata: 73 lines, 2.3 KiB, Python

import copy
from collections import OrderedDict
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.parallel import ColoDDP
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
    """Assert that two state dicts have the same keys, in order, with equal tensors.

    Args:
        state_dict: Reference mapping of names to tensors.
        other_state_dict: Mapping compared entry-by-entry against ``state_dict``.

    Raises:
        AssertionError: If the number of entries differs, if the keys at the
            same position differ, or if a pair of tensors is not equal
            according to ``torch.equal``.
    """
    # zip() stops at the shorter input, so without this check a state dict
    # with missing (or extra) trailing entries would be accepted silently.
    assert len(state_dict) == len(other_state_dict), \
        "state dict size mismatch: {} vs {}".format(len(state_dict), len(other_state_dict))

    for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
        assert k1 == k2, "state dict key mismatch: {!r} vs {!r}".format(k1, k2)
        # torch.equal needs both operands on the same device, so move t2 over
        # to t1's device when they differ.
        temp_t2 = t2 if t1.device == t2.device else t2.to(t1.device)
        assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
def init_ddp(module: torch.nn.Module) -> ColoDDP:
    """Wrap ``module`` in a ``ColoDDP`` bound to a default-constructed ``ProcessGroup``.

    Args:
        module: The model to wrap.

    Returns:
        The ``ColoDDP`` wrapper around ``module``.
    """
    return ColoDDP(module, process_group=ProcessGroup())
def run_ddp_state_dict():
    """Round-trip a plain torch state dict through a ``ColoDDP``-wrapped model.

    Builds the 'gpt2' test model twice: once as a regular module moved to CUDA,
    and once under ``ColoInitContext`` and wrapped with ``init_ddp``. The plain
    model's state dict is loaded into the wrapped model, and the wrapped
    model's resulting state dict must match it. Every ``ColoParameter`` must
    report a process group both before and after loading.
    """

    def _assert_process_groups_present(ddp_model):
        # Only ColoParameter instances are inspected; other parameters are skipped.
        for p in ddp_model.parameters():
            if not isinstance(p, ColoParameter):
                continue
            assert p.get_process_group() is not None

    components = non_distributed_component_funcs.get_callable('gpt2')()
    # Only the model builder is needed here; the remaining components are unused.
    model_builder, _train_loader, _test_loader, _optimizer_cls, _criterion = components

    reference_model = model_builder().cuda()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = init_ddp(model)
    reference_state = reference_model.state_dict()

    _assert_process_groups_present(model)
    model.load_state_dict(reference_state)
    _assert_process_groups_present(model)

    check_state_dict_equal(reference_state, model.state_dict())
def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai, then run the state-dict check.

    Args:
        rank: Rank passed to ``colossalai.launch`` for this process.
        world_size: Total number of processes taking part.
        port: Port on 'localhost' passed to ``colossalai.launch``.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_ddp_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_state_dict(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist`` on a shared free port."""
    port = free_port()
    # mp.spawn supplies the process index as the first positional argument (rank).
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)
if __name__ == '__main__':
    # Allow running this file directly: run the test with two processes.
    test_state_dict(world_size=2)