From e5aa8333e423c5e8d3b10c4ee7c37838d786a94a Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 12 Dec 2022 16:57:22 +0800
Subject: [PATCH] [NFC] update chunk manager API (#2119)

---
 colossalai/gemini/chunk/chunk.py             |   2 +-
 colossalai/gemini/chunk/manager.py           |  30 ++--
 colossalai/gemini/chunk/search_utils.py      |   2 +-
 colossalai/nn/parallel/data_parallel.py      |  20 +--
 tests/test_gemini/update/test_chunk_mgrv2.py | 142 ++++++++++---------
 5 files changed, 100 insertions(+), 96 deletions(-)

diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py
index 5bd948f57..a0b274197 100644
--- a/colossalai/gemini/chunk/chunk.py
+++ b/colossalai/gemini/chunk/chunk.py
@@ -294,7 +294,7 @@ class Chunk:
         self.chunk_temp = None
         self.__scatter()

-        # always gathered chunk does not have shard
+        # a gathered chunk never has a shard attribute
         if self.keep_gathered:
             return

diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py
index ac73105a0..07fb6c48b 100644
--- a/colossalai/gemini/chunk/manager.py
+++ b/colossalai/gemini/chunk/manager.py
@@ -17,13 +17,13 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """

-    def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
+    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:

         self.device = init_device or get_current_device()
-        self.size_config: Dict[int, int] = dict()
+        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
-            self.size_config[k] = v.pop('chunk_size')
+            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device

         self.chunk_groups: Dict[str, Deque] = dict()
@@ -32,26 +32,28 @@ class ChunkManager:
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

-    def append_tensor(self,
-                      tensor: ColoTensor,
-                      group_type: str,
-                      config_key: int,
-                      cpu_offload: bool = False,
-                      pin_memory: bool = False) -> None:
-        """Append a tensor to a chunk.
+    def register_tensor(self,
+                        tensor: ColoTensor,
+                        group_type: str,
+                        config_key: int,
+                        cpu_offload: bool = False,
+                        pin_memory: bool = False) -> None:
+        """
+        Register a tensor to the chunk manager.
+        The tensor can then be accessed through `get_chunks`.

         Args:
             tensor: the tensor appended to the chunk
-            group_type: the data type of the group
-            config_key: the key of the group's name, usually the size of the dp world
+            group_type: the data type of the group.
+            config_key: the key of the group's name, i.e. the dp world size
             cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
         assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
-        assert config_key in self.size_config
+        assert config_key in self.dp_degree_chunk_size_dict

-        chunk_size = self.size_config[config_key]
+        chunk_size = self.dp_degree_chunk_size_dict[config_key]
         chunk_kwargs = self.kwargs_config[config_key]
         group_name = "{}_{}".format(group_type, config_key)
         chunk_group = self.__get_chunk_group(group_name)
diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py
index f55d87fc2..b92a8b158 100644
--- a/colossalai/gemini/chunk/search_utils.py
+++ b/colossalai/gemini/chunk/search_utils.py
@@ -83,7 +83,7 @@ def search_chunk_configuration(
         filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.

     Returns:
-        Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
+        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in bytes.
     """

     param_order = OrderedParamGenerator()
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 75736f603..14d85489a 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -228,16 +228,16 @@ class ZeroDDP(ColoDDP):
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             p.data = p.data.half()
             dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.append_tensor(tensor=p,
-                                             group_type='fp16_param',
-                                             config_key=dp_world_size,
-                                             cpu_offload=cpu_offload,
-                                             pin_memory=pin_memory)
-            self.chunk_manager.append_tensor(tensor=fp32_p,
-                                             group_type='fp32_param',
-                                             config_key=dp_world_size,
-                                             cpu_offload=cpu_offload,
-                                             pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=p,
+                                               group_type='fp16_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=fp32_p,
+                                               group_type='fp32_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()
diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py
index fa7a9b1b5..7d192fc63 100644
--- a/tests/test_gemini/update/test_chunk_mgrv2.py
+++ b/tests/test_gemini/update/test_chunk_mgrv2.py
@@ -1,70 +1,72 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
-from functools import partial
-from colossalai.gemini.chunk import ChunkManager
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
-from colossalai.utils import free_port
-from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
-from tests.test_tensor.common_utils import debug_print
-
-CUDA_MEM_0 = {False: 512, True: 1024}
-CUDA_MEM_1 = {False: 0, True: 1024}
-CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
-
-
-@parameterize('keep_gathered', [True, False])
-@parameterize('pin_memory', [True, False])
-def exam_chunk_memory(keep_gathered, pin_memory):
-    pg = ProcessGroup()
-
-    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
-
-    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
-    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
-
-    chunk_manager = ChunkManager(config)
-    assert chunk_manager.total_mem['cpu'] == 0
-    assert chunk_manager.total_mem['cuda'] == 0
-
-    for p in params:
-        chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
-    chunk_manager.close_all_groups()
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
-
-    chunks = chunk_manager.get_chunks(params)
-
-    for chunk in chunks:
-        chunk_manager.access_chunk(chunk)
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]
-
-    for chunk in chunks:
-        chunk_manager.release_chunk(chunk)
-
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
-
-    for chunk in chunks:
-        chunk_manager.move_chunk(chunk, torch.device('cpu'))
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_chunk_memory()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@rerun_if_address_is_in_use()
-def test_chunk_manager(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_chunk_manager(2)
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_tensor.common_utils import debug_print
+
+CUDA_MEM_0 = {False: 512, True: 1024}
+CUDA_MEM_1 = {False: 0, True: 1024}
+CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
+
+
+@parameterize('keep_gathered', [True, False])
+@parameterize('pin_memory', [True, False])
+def exam_chunk_memory(keep_gathered, pin_memory):
+    pg = ProcessGroup()
+
+    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
+
+    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
+    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
+
+    chunk_manager = ChunkManager(config)
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == 0
+
+    for p in params:
+        chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
+    chunk_manager.close_all_groups()
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
+
+    chunks = chunk_manager.get_chunks(params)
+
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]
+
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
+
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
+
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_chunk_memory()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_chunk_manager(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_chunk_manager(2)
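
A minimal usage sketch of the renamed API, mirroring tests/test_gemini/update/test_chunk_mgrv2.py. It assumes the distributed context has already been initialized (e.g. via colossalai.launch) with a data-parallel world size of 2; the snippet is illustrative only.

import torch

from colossalai.gemini.chunk import ChunkManager
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

# chunk configuration maps dp_degree -> chunk init args,
# the same shape returned by search_chunk_configuration
config = {2: dict(chunk_size=128, keep_gathered=False)}
chunk_manager = ChunkManager(config)

pg = ProcessGroup()
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]

# register_tensor replaces the former append_tensor; config_key must match
# the dp world size of the tensors' process group (2 in this example)
for p in params:
    chunk_manager.register_tensor(tensor=p, group_type='param', config_key=2, pin_memory=True)
chunk_manager.close_all_groups()

# registered tensors are then reached through their chunks
for chunk in chunk_manager.get_chunks(params):
    chunk_manager.access_chunk(chunk)

ZeroDDP follows the same pattern, passing p.process_group.dp_world_size() as config_key for its 'fp16_param' and 'fp32_param' groups.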