diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/gemini/chunk/__init__.py index 8468a6815..86ff785f7 100644 --- a/colossalai/gemini/chunk/__init__.py +++ b/colossalai/gemini/chunk/__init__.py @@ -1,3 +1,4 @@ -from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk -from .manager import ChunkManager -from .search_utils import clasify_params, search_chunk_configuration +from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState +from .manager import ChunkManager +from .search_utils import clasify_params, search_chunk_configuration +from .utils import init_chunk_manager diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index f309872a4..d7b5c7aa8 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -1,100 +1,108 @@ -import math -from typing import Dict, List -import numpy as np -import torch.nn as nn -from colossalai.tensor import ColoParameter - - -def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: - """Filter those parameters whose size is too large from others. - """ - params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)] - params_size_arr = np.array(params_size) - - std = np.std(params_size_arr) - mean = np.mean(params_size_arr) - upper_limit = mean + 3 * std - - for key in size_dict: - org_list = size_dict[key] - size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list)) - - -def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: - """Get unused byte for a certain chunk size. - """ - acc = 0 - left = 0 - for s in size_list: - if s > left: - acc += left - left = chunk_size - left -= s - return left + acc - - -def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: - params_dict: Dict[int, List[ColoParameter]] = dict() - for param in model.parameters(): - assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" - if getattr(param, '_ddp_to_ignore', False): - continue - - param_key = param.process_group.dp_world_size() - - if param_key not in params_dict: - params_dict[param_key] = [] - params_dict[param_key].append(param) - - return params_dict - - -def search_chunk_configuration( - model: nn.Module, - search_range_mb: float, - search_interval_byte: int, # hidden size is the best value for the interval - min_chunk_size_mb: float = 32, - filter_exlarge_params: bool = True) -> Dict: - search_range_byte = round(search_range_mb * 1024**2) - min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) - assert search_range_byte >= 0 - - params_dict = clasify_params(model) - config_dict: Dict[int, Dict] = dict() - - size_dict: Dict[int, List[int]] = dict() - for key in params_dict: - params_list = params_dict[key] - size_list = [p.numel() for p in params_list] - # let small parameters keep gathered in CUDA all the time - total_size = sum(size_list) - if total_size < min_chunk_size_byte: - config_dict[key] = dict(chunk_size=total_size, keep_gathered=True) - else: - size_dict[key] = size_list - - if filter_exlarge_params: - _filter_exlarge_params(model, size_dict) - - max_size = min_chunk_size_byte - for key in size_dict: - max_size = max(max_size, max(size_dict[key])) - start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) - - min_chunk_waste = float('+inf') - best_chunk_size = start_size - - for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): - temp_waste = 0 - for key in size_dict: - temp_waste += _get_unused_byte(size_dict[key], chunk_size) - if temp_waste < min_chunk_waste: - min_chunk_waste = temp_waste - best_chunk_size = chunk_size - - for key in params_dict: - if key in config_dict: - continue - config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False) - - return config_dict +import math +from typing import Dict, List, Tuple + +import numpy as np +import torch.nn as nn + +from colossalai.tensor import ColoParameter + + +def in_ddp(param: nn.Parameter) -> bool: + return not getattr(param, '_ddp_to_ignore', False) + + +def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: + """Filter those parameters whose size is too large from others. + """ + params_size = [p.numel() for p in model.parameters() if in_ddp(p)] + params_size_arr = np.array(params_size) + + std = np.std(params_size_arr) + mean = np.mean(params_size_arr) + upper_limit = mean + 3 * std + + for key in size_dict: + org_list = size_dict[key] + size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list)) + + +def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: + """Get unused byte for a certain chunk size. + """ + acc = 0 + left = 0 + for s in size_list: + if s > left: + acc += left + left = chunk_size + left -= s + return left + acc + + +def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: + """Clasify each parameter by its size of DP group. + """ + params_dict: Dict[int, List[ColoParameter]] = dict() + for param in model.parameters(): + assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" + if not in_ddp(param): + continue + + param_key = param.process_group.dp_world_size() + + if param_key not in params_dict: + params_dict[param_key] = [] + params_dict[param_key].append(param) + + return params_dict + + +def search_chunk_configuration( + model: nn.Module, + search_range_mb: float, + search_interval_byte: int, # hidden size is the best value for the interval + min_chunk_size_mb: float = 32, + filter_exlarge_params: bool = True) -> Tuple[Dict, int]: + search_range_byte = round(search_range_mb * 1024**2) + min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) + assert search_range_byte >= 0 + + params_dict = clasify_params(model) + config_dict: Dict[int, Dict] = dict() + + size_dict: Dict[int, List[int]] = dict() + for key in params_dict: + params_list = params_dict[key] + size_list = [p.numel() for p in params_list] + # let small parameters keep gathered in CUDA all the time + total_size = sum(size_list) + if total_size < min_chunk_size_byte: + config_dict[key] = dict(chunk_size=total_size, keep_gathered=True) + else: + size_dict[key] = size_list + + if filter_exlarge_params: + _filter_exlarge_params(model, size_dict) + + max_size = min_chunk_size_byte + for key in size_dict: + max_size = max(max_size, max(size_dict[key])) + start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) + + min_chunk_waste = float('+inf') + best_chunk_size = start_size + + for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): + temp_waste = 0 + for key in size_dict: + temp_waste += _get_unused_byte(size_dict[key], chunk_size) + if temp_waste < min_chunk_waste: + min_chunk_waste = temp_waste + best_chunk_size = chunk_size + + for key in params_dict: + if key in config_dict: + continue + config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False) + + return config_dict, min_chunk_waste diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py new file mode 100644 index 000000000..9d87129db --- /dev/null +++ b/colossalai/gemini/chunk/utils.py @@ -0,0 +1,58 @@ +from time import time +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.gemini.chunk import ChunkManager +from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration + + +def init_chunk_manager(model: nn.Module, + init_device: Optional[torch.device] = None, + hidden_dim: Optional[int] = None, + search_range_mb: Optional[float] = None, + min_chunk_size_mb: Optional[float] = None, + filter_exlarge_params: Optional[bool] = None) -> ChunkManager: + + kwargs_dict = dict() + + if hidden_dim: + search_interval_byte = hidden_dim + else: + search_interval_byte = 1024 # 1kb + kwargs_dict["search_interval_byte"] = search_interval_byte + + if search_range_mb: + kwargs_dict["search_range_mb"] = search_range_mb + + if min_chunk_size_mb: + kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb + + if filter_exlarge_params: + kwargs_dict["filter_exlarge_params"] = filter_exlarge_params + + params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)] + total_size = sum(params_sizes) / 1024**2 + + dist.barrier() + begine = time() + + config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) + + dist.barrier() + end = time() + span_s = end - begine + wasted_size /= 1024**2 + + if dist.get_rank() == 0: + print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), + "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), + "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)), + sep='', + flush=True) + dist.barrier() + + chunk_manager = ChunkManager(config_dict, init_device) + return chunk_manager diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index d98018adf..2be962e1a 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -1,21 +1,23 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from functools import partial -from colossalai.nn.parallel import ColoDDP, ZeroDDP -from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Callable, Type -import torch.distributed as dist import os import random +from functools import partial +from typing import Callable, Type + import numpy as np +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.parallel import ColoDDP, ZeroDDP from colossalai.tensor import ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext def set_seed(seed): @@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: - chunk_config = search_chunk_configuration(module, 4, 1024) + chunk_config, _ = search_chunk_configuration(module, 4, 1024) chunk_manager = ChunkManager(chunk_config) gemini_manager = GeminiManager('cuda', chunk_manager) return ZeroDDP(module, gemini_manager) diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index 4b9694c0d..eb433f2c3 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -1,105 +1,104 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal -from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager -from colossalai.nn.parallel import ZeroDDP -from colossalai.testing import parameterize -from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.tensor import ProcessGroup -from tests.test_tensor.common_utils import debug_print - - -def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): - chunk_manager = model.chunk_manager - param_list = [p for p in model.parameters()] - chunk_list = chunk_manager.get_chunks(param_list) - for chunk in chunk_list: - chunk_manager.access_chunk(chunk) - - for (p0, p1) in zip(model.parameters(), torch_model.parameters()): - assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item()) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): - optimizer.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -def exam_gpt_fwd_bwd(placement_policy): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - torch_model = model_builder().cuda() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - pg = ProcessGroup() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) - - model.eval() - torch_model.eval() - - set_seed(pg.dp_local_rank()) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - if i > 0: - break - - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - model.backward(loss) - - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) - assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format( - torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits) - - check_grad(model, torch_model) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_gpt_fwd_bwd() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_gpt(1) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal + + +def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + param_list = [p for p in model.parameters()] + chunk_list = chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + chunk_manager.access_chunk(chunk) + + for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item()) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): + optimizer.zero_grad() + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + optimizer.backward(loss) + return logits + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +def exam_gpt_fwd_bwd(placement_policy): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + pg = ProcessGroup() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + model.eval() + torch_model.eval() + + set_seed(pg.dp_local_rank()) + for i, (input_ids, attn_mask) in enumerate(train_dataloader): + if i > 0: + break + + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + model.backward(loss) + + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format( + torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits) + + check_grad(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_gpt_fwd_bwd() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(1) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index 3c82258a5..62822f133 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -1,118 +1,116 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal -from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.nn.parallel import ZeroDDP -from colossalai.nn.optimizer import HybridAdam -from colossalai.zero import ZeroOptimizer -from colossalai.testing import parameterize -from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.gemini_mgr import GeminiManager -from tests.test_tensor.common_utils import debug_print - -from time import time -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): - zero_dict = model.state_dict(only_rank_0=False) - torch_dict = torch_model.state_dict() - - for key, value in torch_dict.items(): - # key is 'module.model.PARAMETER', so we truncate it - key = key[7:] - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) - - -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): - optimizer.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optimizer.backward(loss) - return logits - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -def exam_gpt_fwd_bwd(placement_policy): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - torch_model = model_builder().cuda() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - optimizer = HybridAdam(model.parameters(), lr=1e-3) - zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) - torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) - torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) - torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) - - model.eval() - torch_model.eval() - - set_seed(dist.get_rank() * 3 + 128) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - if i > 2: - break - - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) - assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) - # debug_print([0], zero_logits, torch_logits) - - zero_optim.step() - torch_optim.step() - - check_param(model, torch_model) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_gpt_fwd_bwd() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_gpt(1) +from functools import partial +from time import time + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): + optimizer.zero_grad() + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + optimizer.backward(loss) + return logits + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) +def exam_gpt_fwd_bwd(placement_policy): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, attn_mask) in enumerate(train_dataloader): + if i > 2: + break + + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) + # debug_print([0], zero_logits, torch_logits) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_gpt_fwd_bwd() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(1) diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py index 6655c3e39..e0b4e207f 100644 --- a/tests/test_gemini/update/test_search.py +++ b/tests/test_gemini/update/test_search.py @@ -1,66 +1,65 @@ -import pytest - -from functools import partial - -import torch -import torch.multiprocessing as mp -import torch.distributed as dist - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.gemini.chunk import search_chunk_configuration -from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup -from tests.components_to_test.registry import non_distributed_component_funcs - - -def init_1d_row_spec(model, pg: ProcessGroup): - tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - for n, p in model.named_parameters(): - if 'weight' in n and 'ln' not in n: - p.set_process_group(pg) - p.set_tensor_spec(*tensor_spec) - - -def exam_search_chunk_size(): - - world_size = torch.distributed.get_world_size() - pg_tp = ProcessGroup(tp_degree=world_size) - - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - # make sure torch_model and model has the same parameter values - with ColoInitContext(device=get_current_device()): - model = model_builder() - init_1d_row_spec(model, pg_tp) - config_dict = search_chunk_configuration(model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True) - - for key in config_dict: - chunk_size = config_dict[key]['chunk_size'] - if world_size == 1: - assert chunk_size == 31616 - else: - assert chunk_size == 1024 - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_search_chunk_size() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_search(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_search(4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import search_chunk_configuration +from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs + + +def init_1d_row_spec(model, pg: ProcessGroup): + tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + for n, p in model.named_parameters(): + if 'weight' in n and 'ln' not in n: + p.set_process_group(pg) + p.set_tensor_spec(*tensor_spec) + + +def exam_search_chunk_size(): + + world_size = torch.distributed.get_world_size() + pg_tp = ProcessGroup(tp_degree=world_size) + + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + # make sure torch_model and model has the same parameter values + with ColoInitContext(device=get_current_device()): + model = model_builder() + init_1d_row_spec(model, pg_tp) + config_dict, _ = search_chunk_configuration(model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True) + + for key in config_dict: + chunk_size = config_dict[key]['chunk_size'] + if world_size == 1: + assert chunk_size == 31616 + else: + assert chunk_size == 1024 + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_search_chunk_size() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_search(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_search(4) diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index 69f46b900..ea2783fb8 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -1,110 +1,108 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ZeroDDP -from colossalai.testing import parameterize -from colossalai.gemini.gemini_mgr import GeminiManager -from tests.test_tensor.common_utils import debug_print - -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_state_dict(placement_policy, keep_gathered): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - torch_model = model_builder() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p.data) - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - model.train() - - zero_dict = model.state_dict(only_rank_0=False) - torch_dict = torch_model.state_dict() - - for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_load_state_dict(placement_policy, keep_gathered): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - set_seed(451) - torch_model = model_builder() # get a different model - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - torch_dict = torch_model.state_dict() - model.load_state_dict(torch_dict, strict=False) - zero_dict = model.state_dict(only_rank_0=False) - - for key, value in torch_dict.items(): - if key == 'model.lm_head.weight': - continue - assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) - assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_state_dict() - exam_load_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_ddp(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_zero_ddp(1) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_load_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + torch_dict = torch_model.state_dict() + model.load_state_dict(torch_dict, strict=False) + zero_dict = model.state_dict(only_rank_0=False) + + for key, value in torch_dict.items(): + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + exam_load_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_ddp(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_ddp(1) diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index 9361c4b67..74761668a 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -1,97 +1,95 @@ -import pytest -import colossalai -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext - -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ZeroDDP -from colossalai.zero import ZeroOptimizer -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize -from colossalai.gemini.gemini_mgr import GeminiManager -from tests.test_tensor.common_utils import debug_print - -from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager - - -@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) -@parameterize('keep_gathered', [True, False]) -def exam_zero_optim_state_dict(placement_policy, keep_gathered): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - - set_seed(451) - torch_model = model_builder() # get a different model - - world_size = torch.distributed.get_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered - - if placement_policy != 'cuda': - init_device = torch.device('cpu') - else: - init_device = None - chunk_manager = ChunkManager(config_dict, init_device=init_device) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) - - optimizer = HybridAdam(model.parameters()) - optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 - - set_seed(dist.get_rank() * 3 + 128) - model.train() - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - if i > 0: - break - optim.zero_grad() - logits = model(input_ids, attn_mask) - logits = logits.float() - loss = criterion(logits, input_ids) - optim.backward(loss) - optim.step() - - optim_state_dict = optim.state_dict() - optim.load_state_dict(optim_state_dict) - new_state = optim.state_dict()['state'] - org_state = optim_state_dict['state'] - - for k, v in org_state.items(): - w = new_state[k] - for n, m in v.items(): - if isinstance(m, torch.Tensor): - o = w[n] - if m.device != o.device: - o = o.to(m.device) - assert torch.equal(m, o) - else: - assert m == w[n] - - -def run_dist(rank, world_size, port): - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_zero_optim_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_optim(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_zero_optim(1) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import debug_print, set_seed + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_zero_optim_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters()) + optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + + set_seed(dist.get_rank() * 3 + 128) + model.train() + for i, (input_ids, attn_mask) in enumerate(train_dataloader): + if i > 0: + break + optim.zero_grad() + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + optim.backward(loss) + optim.step() + + optim_state_dict = optim.state_dict() + optim.load_state_dict(optim_state_dict) + new_state = optim.state_dict()['state'] + org_state = optim_state_dict['state'] + + for k, v in org_state.items(): + w = new_state[k] + for n, m in v.items(): + if isinstance(m, torch.Tensor): + o = w[n] + if m.device != o.device: + o = o.to(m.device) + assert torch.equal(m, o) + else: + assert m == w[n] + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_zero_optim_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_optim(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_optim(1) diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 70cb837d8..ad5a83e57 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -1,23 +1,24 @@ +from functools import partial + import pytest -import colossalai import torch import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from functools import partial -from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal -from tests.components_to_test.registry import non_distributed_component_funcs from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.nn.parallel import ZeroDDP -from colossalai.nn.optimizer import HybridAdam -from colossalai.zero import ZeroOptimizer -from colossalai.testing import parameterize + +import colossalai from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal from tests.test_tensor.model.test_gpt2 import init_megatron_spec @@ -88,7 +89,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None): tp_init_spec_func(model, pg) dp_world_size = pg.dp_world_size() - config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[dp_world_size]['chunk_size'] = 5000 config_dict[dp_world_size]['keep_gathered'] = False if placement_policy != 'cuda':