ColossalAI/tests/test_zero_data_parallel/test_sharded_optim.py

179 lines
5.4 KiB
Python
Raw Normal View History

import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.zero import ShardedOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.utils import free_port
from functools import partial
def check_equal(a, b):
    """
    Assert that two tensors have the same shape and are equal within tolerance.

    Both tensors are cast to fp32 before the comparison, so tensors of
    different floating point precision (e.g. fp16 vs fp32) can be compared.

    Args:
        a: the first tensor.
        b: the second tensor.

    Raises:
        AssertionError: if the shapes differ, or if any pair of elements
            differs by more than ``atol=1e-3 + rtol=1e-4 * |b|``.
    """
    # torch.allclose broadcasts its inputs, so a shape mismatch such as
    # (3,) vs (1,) would otherwise pass silently
    assert a.shape == b.shape, f'shape mismatch: a.shape = {a.shape}, b.shape = {b.shape}'
    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
def check_completely_equal(a, b):
    """
    Assert that two tensors have the same shape and are exactly equal element-wise.

    Args:
        a: the first tensor.
        b: the second tensor.

    Raises:
        AssertionError: if the shapes differ or if any element differs.
    """
    # `a == b` broadcasts, so a shape mismatch such as (3,) vs (1,)
    # would otherwise be reported as "completely equal"
    assert a.shape == b.shape, f'shape mismatch: a.shape = {a.shape}, b.shape = {b.shape}'
    assert torch.all(a == b), f'a = {a}, b = {b}'
def check_sharded_param_consistency():
    """
    Check that zero stage 1 and stage 2 deliver the same numerical results
    despite their different communication patterns.

    The following prefixes differentiate the two zero stages:
        oss: partition optimizer states
        pg: partition gradients and optimizer states
    """
    # build one model and deep-copy it so both stages start from identical weights
    oss_model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 512))
    pg_model = copy.deepcopy(oss_model)
    oss_model = oss_model.cuda().half()
    pg_model = pg_model.cuda().half()

    # wrap a plain Adam optimizer for each stage; only `partition_grad` differs
    oss_adam = torch.optim.Adam(oss_model.parameters(), lr=0.001)
    pg_adam = torch.optim.Adam(pg_model.parameters(), lr=0.001)
    shared_kwargs = dict(overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
    oss_optimizer = ShardedOptimizer(oss_adam, **shared_kwargs)
    pg_optimizer = ShardedOptimizer(pg_adam, partition_grad=True, **shared_kwargs)

    # create the input data shared by both models
    input_data = torch.rand(32, 128).cuda().half()

    # forward pass: identical weights and input must give identical outputs
    oss_output = oss_model(input_data)
    pg_output = pg_model(input_data)
    check_completely_equal(oss_output, pg_output)

    # backward pass through the sharded optimizers
    oss_optimizer.backward(oss_output.mean().float())
    pg_optimizer.backward(pg_output.mean().float())

    # compare weight gradients layer by layer
    # as these params are small, the backward reduction will not be fired
    for oss_layer, pg_layer in zip(oss_model, pg_model):
        check_completely_equal(oss_layer.weight.grad, pg_layer.weight.grad)

    # synchronize gradients
    oss_optimizer.sync_grad()
    pg_optimizer.sync_grad()

    # optimizer step
    oss_optimizer.step()
    pg_optimizer.step()

    # the updated weights must still match exactly
    for oss_layer, pg_layer in zip(oss_model, pg_model):
        check_completely_equal(oss_layer.weight, pg_layer.weight)
def check_sharded_optim_against_torch_ddp():
    """
    Compare the sharded optimizer against torch DDP.

    Two model/optimizer pairs are created:
        1. zero: sharded optimizer with fp16 parameters
        2. torch: torch DDP with fp32 parameters
    Both receive the same input, and the model outputs, gradients and
    updated parameters are checked to agree within tolerance.
    """
    # build one model and deep-copy it so both setups start from identical weights
    zero_model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 512))
    torch_model = copy.deepcopy(zero_model)
    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda())

    # only stage 1 is tested here; `check_sharded_param_consistency` tests
    # whether stage 1 and stage 2 produce exactly the same results
    zero_adam = torch.optim.Adam(zero_model.parameters(), lr=0.001)
    zero_optimizer = ShardedOptimizer(zero_adam, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

    # create the fp32 input data; the zero model receives an fp16 copy
    input_data = torch.rand(32, 128).cuda()

    # forward pass: zero-dp first, then torch-ddp
    zero_output = zero_model(input_data.half())
    torch_output = torch_model(input_data)
    check_equal(zero_output, torch_output)

    # backward pass: zero-dp first, then torch-ddp
    zero_optimizer.backward(zero_output.mean().float())
    torch_output.mean().backward()

    # compare weight gradients layer by layer
    for zero_layer, torch_layer in zip(zero_model, torch_model.module):
        check_equal(zero_layer.weight.grad, torch_layer.weight.grad)

    # zero-dp step
    zero_optimizer.sync_grad()
    zero_optimizer.step()

    # torch-ddp step
    torch_optimizer.step()

    # compare the updated weights layer by layer
    for zero_layer, torch_layer in zip(zero_model, torch_model.module):
        check_equal(zero_layer.weight, torch_layer.weight)
def run_dist(rank, world_size, port):
    """
    Per-process entry point: call ``colossalai.launch`` and run both checks.

    Args:
        rank: rank passed to ``colossalai.launch`` for this process.
        world_size: world size passed to ``colossalai.launch``.
        port: port passed to ``colossalai.launch``; the host is 'localhost'.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port)

    # run the DDP comparison first, then the stage 1 vs stage 2 comparison
    for check in (check_sharded_optim_against_torch_ddp, check_sharded_param_consistency):
        check()
@pytest.mark.dist
def test_sharded_optim():
    """
    Spawn two processes that each execute ``run_dist``.
    """
    num_procs = 2
    port = free_port()
    # world_size and port are bound here; mp.spawn passes the process index
    # as the first positional argument, which run_dist receives as `rank`
    worker = partial(run_dist, world_size=num_procs, port=port)
    mp.spawn(worker, nprocs=num_procs)
if __name__ == '__main__':
    # allow running this test directly as a script, without going through pytest
    test_sharded_optim()