ColossalAI/tests/test_elixir/test_search/test_simple.py

49 lines
1.5 KiB
Python

from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
@run_on_environment_flag('ELX')
def test_simple_search():
model_fn, data_fn = TEST_MODELS.get('small')
model = model_fn()
data = data_fn()
sr = simple_search(model,
1,
split_number=5,
shard_device=gpu_device(),
prefetch=True,
verbose=True,
inp=data,
step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
private_plan = chunk_plans.pop(0)
assert private_plan.name_list == ['embed.weight']
assert private_plan.chunk_size == 320
assert private_plan.kwargs.get('shard_device') == gpu_device()
assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias']
assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias']
assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias']
assert chunk_plans[3].name_list == ['norm2.weight']
assert chunk_plans[4].name_list == ['norm2.bias']
for plan in chunk_plans:
assert plan.chunk_size == 1088
assert plan.kwargs.get('shard_device') == gpu_device()
if __name__ == '__main__':
test_simple_search()