mirror of https://github.com/hpcaitech/ColossalAI
tests: add `sub_dp_group` test
parent
9291f07964
commit
6ceaf4f1f8
|
@ -8,12 +8,28 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.testing import spawn
|
||||
from colossalai.testing import parameterize, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.utils import conditional_context
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
rtol = None
|
||||
atol = None
|
||||
if dtype is torch.float16:
|
||||
rtol = 5e-2
|
||||
atol = 5e-4
|
||||
elif dtype is torch.bfloat16:
|
||||
rtol = 4e-3
|
||||
atol = 4e-3
|
||||
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype)
|
||||
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
class MlpModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MlpModel, self).__init__()
|
||||
|
@ -26,7 +42,9 @@ class MlpModel(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def exam_zero_1_2_grad_acc():
|
||||
@parameterize("sub_dp_size", [1, 2])
|
||||
def exam_zero_1_2_grad_acc(sub_dp_size: int):
|
||||
assert torch.distributed.get_world_size() % sub_dp_size == 0
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(2009)
|
||||
device = get_accelerator().get_current_device()
|
||||
|
@ -37,10 +55,20 @@ def exam_zero_1_2_grad_acc():
|
|||
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
||||
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
|
||||
zero1_optimizer = LowLevelZeroOptimizer(
|
||||
zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True
|
||||
zero1_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0,
|
||||
verbose=True,
|
||||
sub_dp_size=sub_dp_size,
|
||||
)
|
||||
zero2_optimizer = LowLevelZeroOptimizer(
|
||||
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0
|
||||
zero2_optimizer,
|
||||
overlap_communication=True,
|
||||
partition_grad=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0,
|
||||
sub_dp_size=sub_dp_size,
|
||||
)
|
||||
# create data
|
||||
seed_all(2021 + local_rank)
|
||||
|
@ -51,7 +79,7 @@ def exam_zero_1_2_grad_acc():
|
|||
# zero-dp forward
|
||||
zero1_output = zero1_model(cur_data)
|
||||
zero2_output = zero2_model(cur_data)
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
loose_close(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.sum().float())
|
||||
|
@ -66,10 +94,13 @@ def exam_zero_1_2_grad_acc():
|
|||
|
||||
# check updated param
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
assert torch.equal(z1p.data, z2p.data)
|
||||
loose_close(z1p.data, z2p.data)
|
||||
|
||||
|
||||
def exam_zero_1_grad_acc(sync):
|
||||
@parameterize("no_sync", [True, False])
|
||||
@parameterize("sub_dp_size", [1, 2])
|
||||
def exam_zero_1_grad_acc(no_sync: bool, sub_dp_size: int):
|
||||
assert torch.distributed.get_world_size() % sub_dp_size == 0
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(2008)
|
||||
device = get_accelerator().get_current_device()
|
||||
|
@ -89,7 +120,11 @@ def exam_zero_1_grad_acc(sync):
|
|||
# in `check_sharded_param_consistency.py`, we will test whether
|
||||
# level 1 and 2 will produce exactly the same results
|
||||
zero_optimizer = LowLevelZeroOptimizer(
|
||||
zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0
|
||||
zero_optimizer,
|
||||
overlap_communication=False,
|
||||
reduce_bucket_size=262144,
|
||||
clip_grad_norm=1.0,
|
||||
sub_dp_size=sub_dp_size,
|
||||
)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
|
@ -108,38 +143,37 @@ def exam_zero_1_grad_acc(sync):
|
|||
# torch-ddp fwd and bwd
|
||||
with conditional_context(torch_model.no_sync(), no_sync):
|
||||
torch_output = torch_model(cur_data)
|
||||
assert torch.equal(zero_output, torch_output)
|
||||
loose_close(zero_output, torch_output)
|
||||
torch_output.sum().backward()
|
||||
|
||||
if check_flag:
|
||||
# check grad
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
assert torch.equal(p.grad, z1p.grad)
|
||||
loose_close(p.grad, z1p.grad)
|
||||
|
||||
fwd_bwd_func(sync, input_data1, sync)
|
||||
fwd_bwd_func(no_sync, input_data1, no_sync)
|
||||
fwd_bwd_func(False, input_data2, False)
|
||||
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||
|
||||
zero_optimizer.step()
|
||||
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||
torch_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
# print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
|
||||
assert_close(p.data, z1p.data)
|
||||
loose_close(p.data, z1p.data)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
||||
exam_zero_1_grad_acc(sync=True)
|
||||
exam_zero_1_grad_acc(sync=False)
|
||||
exam_zero_1_grad_acc()
|
||||
exam_zero_1_2_grad_acc()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_grad_accumulation():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size):
|
|||
return splited_grad
|
||||
|
||||
|
||||
def exam_zero_1_2():
|
||||
@parameterize("sub_dp_size", [1, 2])
|
||||
def exam_zero_1_2(sub_dp_size: int):
|
||||
"""
|
||||
In this test, we want to test whether zero stage 1 and 2
|
||||
deliver the same numerical results despite different communication
|
||||
|
@ -62,6 +63,7 @@ def exam_zero_1_2():
|
|||
pg: partition gradients and optimizer states
|
||||
|
||||
"""
|
||||
assert torch.distributed.get_world_size() % sub_dp_size == 0
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(2001)
|
||||
|
||||
|
@ -73,10 +75,10 @@ def exam_zero_1_2():
|
|||
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
||||
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
|
||||
zero1_optimizer = LowLevelZeroOptimizer(
|
||||
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True
|
||||
zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True, sub_dp_size=sub_dp_size
|
||||
)
|
||||
zero2_optimizer = LowLevelZeroOptimizer(
|
||||
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128
|
||||
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128, sub_dp_size=sub_dp_size
|
||||
)
|
||||
# create data
|
||||
seed_all(2001 + local_rank)
|
||||
|
@ -94,7 +96,7 @@ def exam_zero_1_2():
|
|||
z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0)
|
||||
z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0)
|
||||
for z1g, z2g in zip(z1g_list, z2g_list):
|
||||
assert torch.equal(z1g, z2g)
|
||||
loose_close(z1g, z2g)
|
||||
|
||||
# step
|
||||
zero1_optimizer.step()
|
||||
|
@ -102,12 +104,13 @@ def exam_zero_1_2():
|
|||
|
||||
# check updated param
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
assert torch.equal(z1p.data, z2p.data)
|
||||
loose_close(z1p.data, z2p.data)
|
||||
|
||||
|
||||
@parameterize("dtype", [torch.float16, torch.bfloat16])
|
||||
@parameterize("master_weights", [True, False])
|
||||
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
@parameterize("sub_dp_size", [1, 2])
|
||||
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool, sub_dp_size: int):
|
||||
"""
|
||||
In this test, two pairs of model and optimizers are created.
|
||||
1. zero: use sharded optimizer and fp16 parameters
|
||||
|
@ -116,6 +119,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
|||
We feed these two sets of models with the same input and check if the
|
||||
differences in model output and updated parameters are within tolerance.
|
||||
"""
|
||||
assert world_size % sub_dp_size == 0
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
|
@ -137,6 +141,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
|||
initial_scale=1,
|
||||
reduce_bucket_size=1024 * 1024,
|
||||
master_weights=master_weights,
|
||||
sub_dp_size=sub_dp_size,
|
||||
)
|
||||
|
||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||
|
@ -162,7 +167,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
|||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
if p.grad is not None:
|
||||
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p))
|
||||
torch_grad_list = split_ddp_grad(p.grad, world_size)
|
||||
torch_grad_list = split_ddp_grad(p.grad, world_size // sub_dp_size)
|
||||
for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
|
||||
loose_close(zero_grad, torch_grad, dtype=dtype)
|
||||
|
||||
|
@ -187,7 +192,7 @@ def run_dist(rank, world_size, port):
|
|||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_1_2():
|
||||
spawn(run_dist, 2)
|
||||
spawn(run_dist, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue