@@ -1,7 +1,6 @@
 from contextlib import nullcontext
 from typing import Optional
 
-import pytest
 import torch
 import torch.distributed as dist
 
@@ -12,13 +11,7 @@ from colossalai.fx import is_compatible_with_meta
 from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import (
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    skip_if_not_enough_gpus,
-    spawn,
-)
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import COMMON_MODELS, IS_FAST_TEST, model_zoo
@@ -177,12 +170,5 @@ def test_gemini_plugin(early_stop: bool = True):
     spawn(run_dist, 4, early_stop=early_stop)
 
 
-@pytest.mark.largedist
-@skip_if_not_enough_gpus(8)
-@rerun_if_address_is_in_use()
-def test_gemini_plugin_3d(early_stop: bool = True):
-    spawn(run_dist, 8, early_stop=early_stop)
-
-
 if __name__ == "__main__":
     test_gemini_plugin(early_stop=False)