diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py
index 07fb6c48b..e73c59b25 100644
--- a/colossalai/gemini/chunk/manager.py
+++ b/colossalai/gemini/chunk/manager.py
@@ -140,6 +140,14 @@ class ChunkManager:
         self.__add_memory_usage(chunk.memory_usage)
         return True
 
+    def fake_release_chunk(self, chunk: Chunk) -> None:
+        """Release a gathered chunk in fake mode.
+        This is used for keep-gathered chunks in inference mode.
+        """
+        assert chunk.keep_gathered
+        assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors
+        self.__sub_accessed_chunk(chunk)
+
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
         """
         Copy data to the chunk.
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index a313da59b..8e0192c71 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -257,8 +257,11 @@ class ZeroDDP(ColoDDP):
         access_list = list(self.chunk_manager.accessed_chunks)
         # we need to scatter all accessed chunks and move them to their original places
         for chunk in access_list:
-            assert chunk.can_release
-            self.chunk_manager.release_chunk(chunk)
+            if chunk.keep_gathered:
+                self.chunk_manager.fake_release_chunk(chunk)
+            else:
+                assert chunk.can_release
+                self.chunk_manager.release_chunk(chunk)
             first_param = next(iter(chunk.tensors_info))
             self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
         assert self.chunk_manager.accessed_mem == 0
diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_gemini/update/test_inference.py
index 443155865..b057448ad 100644
--- a/tests/test_gemini/update/test_inference.py
+++ b/tests/test_gemini/update/test_inference.py
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Callable
 
 import pytest
 import torch
@@ -13,7 +14,7 @@ from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chu
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
@@ -36,9 +37,35 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
 
 
+def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
+    world_size = dist.get_world_size()
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[world_size]['chunk_size'] = 5000
+    config_dict[world_size]['keep_gathered'] = False
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    return model
+
+
+def single_chunk_init(model: torch.nn.Module, placement_policy: str):
+    gemini_config = dict(
+        device=get_current_device(),
+        placement_policy=placement_policy,
+        pin_memory=True,
+    )
+    model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
+    return model
+
+
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', ['gpt2'])
-def exam_inference(placement_policy, model_name: str):
+@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
+def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
     set_seed(19360226)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -56,18 +83,7 @@ def exam_inference(placement_policy, model_name: str):
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
-    world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
-    config_dict[world_size]['chunk_size'] = 5000
-    config_dict[world_size]['keep_gathered'] = False
-    if placement_policy != 'cuda':
-        init_device = torch.device('cpu')
-    else:
-        init_device = None
-    chunk_manager = ChunkManager(config_dict, init_device=init_device)
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
-
+    model = model_init_func(model, placement_policy)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
 
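
A minimal sketch of how the new keep-gathered inference path is exercised outside the test harness, assuming a process group has already been initialized via colossalai.launch; the toy nn.Sequential model, its sizes, and the random input are hypothetical stand-ins, and only zero_model_wrapper with its zero_stage/gemini_config arguments is taken from the patch itself:

    import torch
    import torch.nn as nn

    from colossalai.nn.parallel import zero_model_wrapper
    from colossalai.utils.cuda import get_current_device

    # A toy model small enough that Gemini may place all parameters in a
    # single gathered chunk, i.e. the keep_gathered case this patch handles.
    model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 8))
    model = zero_model_wrapper(
        model=model,
        zero_stage=3,
        gemini_config=dict(device=get_current_device(),
                           placement_policy='cuda',
                           pin_memory=True),
    )

    model.eval()
    with torch.no_grad():
        # Forward-only pass. When ZeroDDP later scatters accessed chunks back
        # to their original places, keep-gathered chunks take the new
        # fake_release_chunk branch, which only resets the access bookkeeping
        # instead of tripping `assert chunk.can_release`.
        out = model(torch.rand(4, 32, device=get_current_device()))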