tests: add `sub_dp_group` test

pull/5817/head
Wenhao Chen 2024-04-01 14:51:36 +08:00 committed by アマデウス
parent 9291f07964
commit 6ceaf4f1f8
2 changed files with 63 additions and 24 deletions

View File

@@ -8,12 +8,28 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.accelerator import get_accelerator
-from colossalai.testing import spawn
+from colossalai.testing import parameterize, spawn
 from colossalai.testing.random import seed_all
 from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+    rtol = None
+    atol = None
+    if dtype is torch.float16:
+        rtol = 5e-2
+        atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
+
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype)
+
+    assert_close(a, b, rtol=rtol, atol=atol)


 class MlpModel(nn.Module):
     def __init__(self):
         super(MlpModel, self).__init__()
@@ -26,7 +42,9 @@ class MlpModel(nn.Module):
         return x


-def exam_zero_1_2_grad_acc():
+@parameterize("sub_dp_size", [1, 2])
+def exam_zero_1_2_grad_acc(sub_dp_size: int):
+    assert torch.distributed.get_world_size() % sub_dp_size == 0
     local_rank = torch.distributed.get_rank()
     seed_all(2009)
     device = get_accelerator().get_current_device()
@@ -37,10 +55,20 @@ def exam_zero_1_2_grad_acc():
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(
-        zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True
+        zero1_optimizer,
+        overlap_communication=True,
+        initial_scale=32,
+        clip_grad_norm=1.0,
+        verbose=True,
+        sub_dp_size=sub_dp_size,
     )
     zero2_optimizer = LowLevelZeroOptimizer(
-        zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0
+        zero2_optimizer,
+        overlap_communication=True,
+        partition_grad=True,
+        initial_scale=32,
+        clip_grad_norm=1.0,
+        sub_dp_size=sub_dp_size,
     )
     # create data
     seed_all(2021 + local_rank)
@@ -51,7 +79,7 @@ def exam_zero_1_2_grad_acc():
         # zero-dp forward
         zero1_output = zero1_model(cur_data)
         zero2_output = zero2_model(cur_data)
-        assert torch.equal(zero1_output, zero2_output)
+        loose_close(zero1_output, zero2_output)

         # zero-dp backward
         zero1_optimizer.backward(zero1_output.sum().float())
@@ -66,10 +94,13 @@ def exam_zero_1_2_grad_acc():
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
-        assert torch.equal(z1p.data, z2p.data)
+        loose_close(z1p.data, z2p.data)


-def exam_zero_1_grad_acc(sync):
+@parameterize("no_sync", [True, False])
+@parameterize("sub_dp_size", [1, 2])
+def exam_zero_1_grad_acc(no_sync: bool, sub_dp_size: int):
+    assert torch.distributed.get_world_size() % sub_dp_size == 0
     local_rank = torch.distributed.get_rank()
     seed_all(2008)
     device = get_accelerator().get_current_device()
@@ -89,7 +120,11 @@ def exam_zero_1_grad_acc(sync):
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
     zero_optimizer = LowLevelZeroOptimizer(
-        zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0
+        zero_optimizer,
+        overlap_communication=False,
+        reduce_bucket_size=262144,
+        clip_grad_norm=1.0,
+        sub_dp_size=sub_dp_size,
     )

     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -108,38 +143,37 @@ def exam_zero_1_grad_acc(sync):
         # torch-ddp fwd and bwd
         with conditional_context(torch_model.no_sync(), no_sync):
             torch_output = torch_model(cur_data)
-            assert torch.equal(zero_output, torch_output)
+            loose_close(zero_output, torch_output)
             torch_output.sum().backward()

         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                assert torch.equal(p.grad, z1p.grad)
+                loose_close(p.grad, z1p.grad)

-    fwd_bwd_func(sync, input_data1, sync)
+    fwd_bwd_func(no_sync, input_data1, no_sync)
     fwd_bwd_func(False, input_data2, False)

+    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
     zero_optimizer.step()
-    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
     torch_optimizer.step()

     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
-        assert_close(p.data, z1p.data)
+        loose_close(p.data, z1p.data)


 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    exam_zero_1_grad_acc(sync=True)
-    exam_zero_1_grad_acc(sync=False)
+    exam_zero_1_grad_acc()
     exam_zero_1_2_grad_acc()


 @pytest.mark.dist
 def test_grad_accumulation():
-    spawn(run_dist, 2)
+    spawn(run_dist, 4)


 if __name__ == "__main__":
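
Note on the new parameterization: a sketch under the assumption that `colossalai.testing.parameterize` behaves here as it does elsewhere in the test suite, rerunning the wrapped function once per listed value, with stacked decorators producing the full cartesian product. That is why `run_dist` now makes a single bare call to `exam_zero_1_grad_acc()` instead of the old `sync=True` / `sync=False` pair, and why each spawned worker covers every `no_sync` x `sub_dp_size` combination. The example below is a hypothetical standalone illustration, not part of the commit:

# Hypothetical illustration of how the stacked decorators expand.
from colossalai.testing import parameterize

@parameterize("no_sync", [True, False])
@parameterize("sub_dp_size", [1, 2])
def exam(no_sync: bool, sub_dp_size: int):
    print(f"no_sync={no_sync}, sub_dp_size={sub_dp_size}")

# One bare call runs all four combinations in-process:
# (True, 1), (True, 2), (False, 1), (False, 2)
exam()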

View File

@@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
     return splited_grad


-def exam_zero_1_2():
+@parameterize("sub_dp_size", [1, 2])
+def exam_zero_1_2(sub_dp_size: int):
     """
     In this test, we want to test whether zero stage 1 and 2
     deliver the same numerical results despite different communication
@@ -62,6 +63,7 @@ def exam_zero_1_2():
     pg: partition gradients and optimizer states
     """
+    assert torch.distributed.get_world_size() % sub_dp_size == 0
     local_rank = torch.distributed.get_rank()
     seed_all(2001)
@@ -73,10 +75,10 @@ def exam_zero_1_2():
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(
-        zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
+        zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True, sub_dp_size=sub_dp_size
     )
     zero2_optimizer = LowLevelZeroOptimizer(
-        zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
+        zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128, sub_dp_size=sub_dp_size
     )
     # create data
     seed_all(2001 + local_rank)
@@ -94,7 +96,7 @@ def exam_zero_1_2():
     z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
     z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
     for z1g, z2g in zip(z1g_list, z2g_list):
-        assert torch.equal(z1g, z2g)
+        loose_close(z1g, z2g)

     # step
     zero1_optimizer.step()
@@ -102,12 +104,13 @@ def exam_zero_1_2():
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
-        assert torch.equal(z1p.data, z2p.data)
+        loose_close(z1p.data, z2p.data)


 @parameterize("dtype", [torch.float16, torch.bfloat16])
 @parameterize("master_weights", [True, False])
-def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
+@parameterize("sub_dp_size", [1, 2])
+def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool, sub_dp_size: int):
     """
     In this test, two pairs of model and optimizers are created.
     1. zero: use sharded optimizer and fp16 parameters
@@ -116,6 +119,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     We feed these two sets of models with the same input and check if the
     differences in model output and updated parameters are within tolerance.
     """
+    assert world_size % sub_dp_size == 0
     local_rank = torch.distributed.get_rank()
     seed_all(1453)
@@ -137,6 +141,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
         initial_scale=1,
         reduce_bucket_size=1024 * 1024,
         master_weights=master_weights,
+        sub_dp_size=sub_dp_size,
     )

     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -162,7 +167,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         if p.grad is not None:
             zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
-            torch_grad_list = split_ddp_grad(p.grad, world_size)
+            torch_grad_list = split_ddp_grad(p.grad, world_size // sub_dp_size)
             for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
                 loose_close(zero_grad, torch_grad, dtype=dtype)
@@ -187,7 +192,7 @@ def run_dist(rank, world_size, port):

 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_zero_1_2():
-    spawn(run_dist, 2)
+    spawn(run_dist, 4)


 if __name__ == "__main__":
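
The diff only passes `sub_dp_size` through to `LowLevelZeroOptimizer`; its implementation is not shown here. Judging from `split_ddp_grad(p.grad, world_size // sub_dp_size)` and the new divisibility asserts, the intended layout appears to be that optimizer states and gradients are sharded across `world_size // sub_dp_size` ranks, with each shard replicated across a sub data-parallel group of `sub_dp_size` ranks. Below is a rough sketch of how such process groups could be built; the helper name, the contiguous-versus-strided rank layout, and the group roles are assumptions for illustration, not the library's actual code:

# Hypothetical sketch of the sub data-parallel layout implied by the test changes;
# not the actual LowLevelZeroOptimizer implementation.
import torch.distributed as dist


def build_sub_dp_groups(sub_dp_size: int):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % sub_dp_size == 0
    shard_group_size = world_size // sub_dp_size  # ranks that share one set of ZeRO shards

    zero_group = None    # group the optimizer states / gradients are partitioned across
    sub_dp_group = None  # replicas that hold identical shards

    # every rank must call new_group() for every group, even groups it does not join
    for i in range(sub_dp_size):
        ranks = list(range(i * shard_group_size, (i + 1) * shard_group_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            zero_group = group

    for j in range(shard_group_size):
        ranks = list(range(j, world_size, shard_group_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            sub_dp_group = group

    return zero_group, sub_dp_group

Under this assumed layout, a world of 4 ranks with `sub_dp_size=2` would give shard groups {0, 1} and {2, 3} plus replica groups {0, 2} and {1, 3}, which is consistent with the tests now spawning 4 processes so that both `sub_dp_size=1` and `sub_dp_size=2` divide the world size.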