mirror of https://github.com/hpcaitech/ColossalAI
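# Checks that LowLevelZeroOptimizer is actually destroyed (its __del__ runs)
# once it goes out of scope, guarding against a memory leak.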
import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero import LowLevelZeroOptimizer

class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(123, 253)

    def forward(self, x):
        x = self.linear1(x)
        return x

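# Module-level flag flipped by TestLowLevelZeroOptimizer.__del__ so the test
# can observe that the optimizer was actually destroyed.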
DEL_CALLED = False


class TestLowLevelZeroOptimizer(LowLevelZeroOptimizer):
    def __del__(self):
        super().__del__()
        global DEL_CALLED
        DEL_CALLED = True

def exam_mem_leak(world_size):
    """
    Check that __del__ is called once the optimizer goes out of scope,
    i.e. that dropping the last reference does not leak the optimizer.
    """
    # create the model
    zero_model = MlpModel().cuda()

    # we only test stage 1 here; `check_sharded_param_consistency.py` verifies
    # that stages 1 and 2 produce exactly the same results
    zero_optimizer = TestLowLevelZeroOptimizer(torch.optim.SGD(zero_model.parameters(), lr=1))

    del zero_optimizer

    assert DEL_CALLED

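# Per-process entry point: initialize the distributed environment on each rank,
# then run the leak check.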
def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    exam_mem_leak(world_size=world_size)

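# pytest entry point: spawn two ranks and rerun if the chosen port is already in use.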
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
    spawn(run_dist, 2)

if __name__ == "__main__":
    test_zero_1_2()