mirror of https://github.com/hpcaitech/ColossalAI
import copy

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
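
# A tiny two-layer MLP; the architecture itself is incidental, it only provides
# parameters whose optimizer states can be compared across the two setups.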
class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(12, 24)
        self.linear2 = nn.Linear(24, 12)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x
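
# Tolerance-aware comparison helper: fp16 and bf16 carry far fewer mantissa bits
# than fp32, so comparisons in those dtypes use looser rtol/atol bounds than the
# assert_close defaults.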
def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)

    assert_close(a, b, rtol=rtol, atol=atol)


def exam_zero_1_torch_ddp_ckpt():
    """
    Compare the optimizer state_dict produced under ZeRO stage 1 with the one
    produced by a plain torch optimizer under DDP, and check that the ZeRO
    optimizer can load a checkpoint saved from the torch optimizer.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)
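
    # NOTE: the shared seed above makes every rank build identical initial weights;
    # DDP would enforce this for torch_model anyway via its parameter broadcast, but
    # the ZeRO copy relies on the seed to stay in sync across ranks.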
    # create models
    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model)

    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # the state dicts of stage 1 and stage 2 are the same
    zero_optimizer = LowLevelZeroOptimizer(
        zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144
    )
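
    # NOTE: in ZeRO stage 1, each rank keeps only the optimizer states for its own
    # parameter shard; state_dict() on the wrapper is expected to return the gathered,
    # unsharded states, which is what allows the entry-by-entry comparison with the
    # plain torch optimizer further down.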
    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

    seed_all(1453 + local_rank)
    # create input data; the per-rank seed offset above gives each rank a different
    # batch, mimicking data-parallel training
    input_data = torch.rand(4, 12).cuda()

    # forward
    zero_output = zero_model(input_data)
    torch_output = torch_model(input_data)

    # backward
    zero_optimizer.backward(zero_output.mean().float())
    torch_output.mean().backward()

    # step
    zero_optimizer.step()
    torch_optimizer.step()
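
    # After one identical step the optimizer states on both sides should agree:
    # DDP averages gradients across ranks during backward, and the ZeRO wrapper
    # performs its own gradient reduction inside backward(), so each parameter
    # receives an equivalent update.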
    torch_state_dict = torch_optimizer.state_dict()
    zero_state_dict = zero_optimizer.state_dict()

    # examine the original state dict
    for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()):
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)

    # empty the optimizer state so the next comparison only reflects what
    # load_state_dict() restores
    zero_optimizer.optim.state = []

    # the ZeRO optimizer loads a torch checkpoint
    zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
    zero_state_dict = zero_optimizer.state_dict()

    # examine the loaded state dict
    for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()):
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)
def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    exam_zero_1_torch_ddp_ckpt()
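
# spawn() runs run_dist on 2 local worker processes (rank 0 and rank 1), passing each
# the same rendezvous port (presumably a free one picked by the test harness);
# rerun_if_address_is_in_use() retries the test if that address turns out to be taken.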
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_zero_ckpt()