|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
from functools import partial |
|
|
|
|
from typing import Callable |
|
|
|
|
|
|
|
|
|
import pytest |
|
|
|
|
import torch |
|
|
|
@ -13,7 +14,7 @@ from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chu
|
|
|
|
|
from colossalai.gemini.gemini_mgr import GeminiManager |
|
|
|
|
from colossalai.nn.optimizer import HybridAdam |
|
|
|
|
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer |
|
|
|
|
from colossalai.nn.parallel import ZeroDDP |
|
|
|
|
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper |
|
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use |
|
|
|
|
from colossalai.utils import free_port |
|
|
|
|
from colossalai.utils.cuda import get_current_device |
|
|
|
@ -36,9 +37,35 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
|
|
|
|
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def multi_chunk_init(model: torch.nn.Module, placement_policy: str): |
|
|
|
|
world_size = dist.get_world_size() |
|
|
|
|
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) |
|
|
|
|
config_dict[world_size]['chunk_size'] = 5000 |
|
|
|
|
config_dict[world_size]['keep_gathered'] = False |
|
|
|
|
if placement_policy != 'cuda': |
|
|
|
|
init_device = torch.device('cpu') |
|
|
|
|
else: |
|
|
|
|
init_device = None |
|
|
|
|
chunk_manager = ChunkManager(config_dict, init_device=init_device) |
|
|
|
|
gemini_manager = GeminiManager(placement_policy, chunk_manager) |
|
|
|
|
model = ZeroDDP(model, gemini_manager, pin_memory=True) |
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def single_chunk_init(model: torch.nn.Module, placement_policy: str): |
|
|
|
|
gemini_config = dict( |
|
|
|
|
device=get_current_device(), |
|
|
|
|
placement_policy=placement_policy, |
|
|
|
|
pin_memory=True, |
|
|
|
|
) |
|
|
|
|
model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config) |
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) |
|
|
|
|
@parameterize('model_name', ['gpt2']) |
|
|
|
|
def exam_inference(placement_policy, model_name: str): |
|
|
|
|
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init]) |
|
|
|
|
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable): |
|
|
|
|
set_seed(19360226) |
|
|
|
|
get_components_func = non_distributed_component_funcs.get_callable(model_name) |
|
|
|
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() |
|
|
|
@ -56,18 +83,7 @@ def exam_inference(placement_policy, model_name: str):
|
|
|
|
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()): |
|
|
|
|
p.data.copy_(torch_p.data) |
|
|
|
|
|
|
|
|
|
world_size = torch.distributed.get_world_size() |
|
|
|
|
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) |
|
|
|
|
config_dict[world_size]['chunk_size'] = 5000 |
|
|
|
|
config_dict[world_size]['keep_gathered'] = False |
|
|
|
|
if placement_policy != 'cuda': |
|
|
|
|
init_device = torch.device('cpu') |
|
|
|
|
else: |
|
|
|
|
init_device = None |
|
|
|
|
chunk_manager = ChunkManager(config_dict, init_device=init_device) |
|
|
|
|
gemini_manager = GeminiManager(placement_policy, chunk_manager) |
|
|
|
|
model = ZeroDDP(model, gemini_manager, pin_memory=True) |
|
|
|
|
|
|
|
|
|
model = model_init_func(model, placement_policy) |
|
|
|
|
optimizer = HybridAdam(model.parameters(), lr=1e-3) |
|
|
|
|
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) |
|
|
|
|
|
|
|
|
|