|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
from contextlib import nullcontext |
|
|
|
|
from typing import Optional |
|
|
|
|
import pytest |
|
|
|
|
|
|
|
|
|
import pytest |
|
|
|
|
import torch |
|
|
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
@ -11,8 +11,6 @@ from colossalai.booster.plugin import GeminiPlugin
|
|
|
|
|
from colossalai.fx import is_compatible_with_meta |
|
|
|
|
from colossalai.lazy.lazy_init import LazyInitContext |
|
|
|
|
from colossalai.nn.optimizer import HybridAdam |
|
|
|
|
from colossalai.tensor.d_tensor.api import clear_layout_converter |
|
|
|
|
from colossalai.shardformer.layer.utils import Randomizer |
|
|
|
|
from colossalai.tensor.colo_parameter import ColoParameter |
|
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
|
|
|
|
from tests.kit.model_zoo import model_zoo |
|
|
|
@ -26,7 +24,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
|
|
|
|
|
ctx = nullcontext() |
|
|
|
|
extra_dp_size = dist.get_world_size() // (zero_size * tp_size) |
|
|
|
|
enable_all_optimization = True if tp_size > 1 else False |
|
|
|
|
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) |
|
|
|
|
plugin = GeminiPlugin( |
|
|
|
|
max_norm=1.0, |
|
|
|
|
initial_scale=2**5, |
|
|
|
|
tp_size=tp_size, |
|
|
|
|
extra_dp_size=extra_dp_size, |
|
|
|
|
enable_all_optimization=enable_all_optimization, |
|
|
|
|
) |
|
|
|
|
booster = Booster(plugin=plugin) |
|
|
|
|
with ctx: |
|
|
|
|
model = model_fn() |
|
|
|
@ -66,7 +70,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
|
|
|
|
|
@parameterize("init_method", ["none"]) |
|
|
|
|
@parameterize("zero_size", [2]) |
|
|
|
|
@parameterize("tp_size", [2]) |
|
|
|
|
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1): |
|
|
|
|
def check_gemini_plugin( |
|
|
|
|
subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1 |
|
|
|
|
): |
|
|
|
|
"""check gemini plugin over model zoo |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -161,6 +167,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
|
|
|
|
|
def test_gemini_plugin(early_stop: bool = True): |
|
|
|
|
spawn(run_dist, 4, early_stop=early_stop) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.largedist |
|
|
|
|
@rerun_if_address_is_in_use() |
|
|
|
|
def test_gemini_plugin_3d(early_stop: bool = True): |
|
|
|
@ -168,4 +175,4 @@ def test_gemini_plugin_3d(early_stop: bool = True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
test_gemini_plugin(early_stop=False) |
|
|
|
|
test_gemini_plugin(early_stop=False) |
|
|
|
|