Making large AI models cheaper, faster and more accessible
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

144 lines
4.5 KiB

import copy
import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
from colossalai.utils import conditional_context
from colossalai.zero import LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank()
seed_all(2009)
# create model
zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(
zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True
)
zero2_optimizer = LowLevelZeroOptimizer(
zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0
)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
zero1_output = zero1_model(cur_data)
zero2_output = zero2_model(cur_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.sum().float())
zero2_optimizer.backward(zero2_output.sum().float())
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
# step
zero1_optimizer.step()
zero2_optimizer.step()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.equal(z1p.data, z2p.data)
def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank()
seed_all(2008)
# create models
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
seed_all(2008)
zero_model = zero_model.cuda()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0
)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
# create data
seed_all(2022 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(no_sync, cur_data, check_flag):
# zero1 fwd and bwd
with conditional_context(zero_optimizer.no_sync(), no_sync):
zero_output = zero_model(cur_data)
zero_optimizer.backward(zero_output.sum().float())
# torch-ddp fwd and bwd
with conditional_context(torch_model.no_sync(), no_sync):
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
torch_output.sum().backward()
if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
assert torch.equal(p.grad, z1p.grad)
fwd_bwd_func(sync, input_data1, sync)
fwd_bwd_func(False, input_data2, False)
zero_optimizer.step()
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
torch_optimizer.step()
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
assert_close(p.data, z1p.data)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
exam_zero_1_grad_acc(sync=True)
exam_zero_1_grad_acc(sync=False)
exam_zero_1_2_grad_acc()
@pytest.mark.dist
def test_grad_accumulation():
spawn(run_dist, 2)
if __name__ == "__main__":
test_grad_accumulation()