import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(123, 253)

    def forward(self, x):
        x = self.linear1(x)
        return x


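# Module-level flag flipped by TestLowLevelZeroOptimizer.__del__ so the test
# can observe that the destructor actually ran.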
DEL_CALLED = False


class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer):
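    # Test-only subclass that records when the optimizer's destructor runs.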
    def __del__(self):
        super().__del__()
        global DEL_CALLED
        DEL_CALLED = True


def exam_mem_leak(world_size):
    """
    In this test, we check whether `__del__` is called after the optimizer
    goes out of scope (i.e. the optimizer is not leaked).
    """
    # create models
    zero_model = MlpModel().cuda()

    # we only test stage 1 here;
    # `check_sharded_param_consistency.py` checks whether
    # stage 1 and stage 2 produce exactly the same results
    zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1))

    del zero_optimizer

    assert DEL_CALLED


def run_dist(rank, world_size, port):
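    # set up the distributed environment on this rank before running the check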
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    exam_mem_leak(world_size=world_size)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
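    # run the leak check on 2 worker processes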
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_zero_1_2()