From 7d14b473f0013e62f2e2d9ba60ac8e39d62eb5fb Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 15 Jun 2022 15:05:19 +0800 Subject: [PATCH] [gemini] gemini mgr supports "cpu" placement policy (#1118) * update gemini mgr * update chunk * add docstr * polish placement policy * update test chunk * update test zero * polish unit test * remove useless unit test --- colossalai/gemini/gemini_mgr.py | 8 +- colossalai/gemini/placement_policy.py | 7 +- colossalai/nn/parallel/data_parallel.py | 2 + colossalai/tensor/chunk.py | 122 ++++++++++++++++++------ tests/test_tensor/test_chunk.py | 1 + tests/test_tensor/test_zero.py | 82 ---------------- tests/test_tensor/test_zero_optim.py | 31 ++++-- 7 files changed, 124 insertions(+), 129 deletions(-) delete mode 100644 tests/test_tensor/test_zero.py diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 6b5d23252..481761c37 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -1,4 +1,4 @@ -import functools +import torch from .memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import List, Optional, Tuple from time import time @@ -15,8 +15,6 @@ class GeminiManager: """ def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: - # TODO: remove assert - assert placement_policy == 'cuda', 'placement_policy can only be "cuda" now' assert placement_policy in PlacementPolicyFactory.get_polocy_names() policy_cls = PlacementPolicyFactory.create(placement_policy) self._chunk_manager = chunk_manager @@ -111,3 +109,7 @@ class GeminiManager: @property def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats + + @staticmethod + def get_default_device(policy_name: str) -> torch.device: + return PlacementPolicyFactory.get_default_device(policy_name) diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index 84d356315..28b841c8d 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -34,10 +34,11 @@ class CPUPlacementPolicy(PlacementPolicy): def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int: volume = 0 + start = time() for chunk in can_evict_chunks: - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False) volume += chunk.mem - return volume, 0 + return volume, time() - start class CUDAPlacementPolicy(PlacementPolicy): @@ -115,7 +116,7 @@ class AutoPlacementPolicy(PlacementPolicy): if freed_cuda_model_data >= to_free_cuda_model_data: break freed_cuda_model_data += chunk.mem - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False) if freed_cuda_model_data < to_free_cuda_model_data: raise RuntimeError( f"Adjust layout failed! No enough CUDA memory! 
Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 38091cd1f..823c355f4 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -100,6 +100,8 @@ class ColoDDPV2(ColoDDP): self.fp32_params = [] self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = {} + self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True) + self.chunk_manager.create_group('fp32_param') # TODO: get param order and filter unused params for p in module.parameters(): assert p.dtype == torch.half diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py index 24243bf42..b3cb07328 100644 --- a/colossalai/tensor/chunk.py +++ b/colossalai/tensor/chunk.py @@ -36,8 +36,21 @@ class ChunkFullError(Exception): pass -class Chunk: +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + +class Chunk: """ A chunk is a contiguous memory space which contains multiple tensors. @@ -46,26 +59,37 @@ class Chunk: src_rank (int): the process which owns the chunk dtype (torch.dtype): the data type of the chunk init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU. + force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False. """ def __init__(self, chunk_size: int, src_rank: int, dtype: torch.dtype, - init_device: Optional[torch.device] = None) -> None: + init_device: Optional[torch.device] = None, + force_data_on_cuda: bool = False) -> None: self.size = chunk_size self.utilized_size = 0 self.src_rank = src_rank self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank] self.dtype = dtype - self.device = init_device or get_current_device() - self.data = torch.empty(chunk_size, dtype=dtype, device=self.device) + device = init_device or get_current_device() + if force_data_on_cuda: + self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device()) + self._cpu_data = torch.empty(chunk_size, dtype=dtype) + if device.type == 'cuda': + free_storage(self._cpu_data) + else: + free_storage(self.data) + else: + self.data = torch.empty(chunk_size, dtype=dtype, device=device) + self._cpu_data = None # we only keep the chunk in full in the process by which the tensor is owned if not self.is_src_rank: - self.data.storage().resize_(0) - + free_storage(self._payload) + # each tensor is associated with a TensorInfo to track meta info self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} self.mem = self.size * self.data.element_size() @@ -83,16 +107,16 @@ class Chunk: # raise exception when the chunk size is exceeded if new_utilized_size > self.size: raise ChunkFullError - + # set tensor state tensor_state = TensorState.FREE # if the process owns the rank, then copy the tensor to its chunk buffer # otherwise set its storage size to 0 to reduce memory consumption if self.is_src_rank: - self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1)) + self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1)) tensor_state = TensorState.HOLD - 
tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape) + tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape) else: tensor.storage().resize_(0) self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) @@ -103,12 +127,12 @@ class Chunk: Release the memory space on processes which do not own the chunk. """ if not self.is_src_rank: - self.data.storage().resize_(0) + free_storage(self._payload) self._update_tensors_state(TensorState.FREE) def _update_tensors_ptr(self) -> None: for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape) + tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): for tensor_info in self.tensors_info.values(): @@ -122,8 +146,8 @@ class Chunk: # recover the chunk on non-owner processes # and broadcast the chunk from the source to all processes if not self.is_src_rank: - self.data.storage().resize_(self.size) - self.data.data = self.data.to(get_current_device()) + alloc_storage(self._payload) + self.move_device(get_current_device(), update_ptr=False) dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) # update tensor meta info @@ -131,15 +155,32 @@ class Chunk: if not self.is_src_rank: self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE) - def move_device(self, device: torch.device) -> None: + def move_device(self, device: torch.device, update_ptr: bool = True) -> None: """ Move the chunk to a target device. Args: device (torch.device): the target device for data movement. """ - self.data.data = self.data.to(device) - self._update_tensors_ptr() + if self._payload.device == device: + return + if self._cpu_data is None: + self.data.data = self.data.to(device) + else: + if device.type == 'cuda': + # cpu -> cuda + src = self._cpu_data + dest = self.data + else: + # cuda -> cpu + src = self.data + dest = self._cpu_data + alloc_storage(dest) + dest.copy_(src) + free_storage(src) + + if update_ptr: + self._update_tensors_ptr() def reduce(self, is_all_reduce: bool = False) -> None: """ @@ -148,7 +189,7 @@ class Chunk: Args: is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false. """ - self.data.data = self.data.to(get_current_device()) + self.move_device(get_current_device(), update_ptr=False) if is_all_reduce: dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA)) else: @@ -187,8 +228,8 @@ class Chunk: data_slice (torch.Tensor): the tensor to be copied to the chunk """ tensor_info = self.tensors_info[tensor] - self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1)) - tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape) + self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1)) + tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) @property def can_release(self) -> bool: @@ -225,7 +266,7 @@ class Chunk: """ Check whether the chunk is empty. 
""" - return self.data.storage().size() == 0 + return is_storage_empty(self._payload) def __repr__(self) -> str: return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}' @@ -235,8 +276,8 @@ class Chunk: """ Check if the chunk has inf or nan values. """ - return torch.isinf(self.data[:self.utilized_size]).any().item() or \ - torch.isnan(self.data[:self.utilized_size]).any().item() + return torch.isinf(self._payload[:self.utilized_size]).any().item() or \ + torch.isnan(self._payload[:self.utilized_size]).any().item() def copy_(self, dest_chunk: 'Chunk'): """ @@ -246,7 +287,7 @@ class Chunk: assert not dest_chunk.is_empty assert self.size == dest_chunk.size assert self.utilized_size == dest_chunk.utilized_size - self.data.copy_(dest_chunk.data) + self._payload.copy_(dest_chunk._payload) self._update_tensors_ptr() @property @@ -254,7 +295,7 @@ class Chunk: """ Get the device type of the chunk. """ - return self.data.device.type + return self._payload.device.type def __hash__(self) -> int: return hash(id(self)) @@ -265,6 +306,12 @@ class Chunk: def get_tensors(self) -> List[torch.Tensor]: return list(self.tensors_info.keys()) + @property + def _payload(self) -> torch.Tensor: + if self._cpu_data is None or is_storage_empty(self._cpu_data): + return self.data + return self._cpu_data + class ChunkManager: """ @@ -285,6 +332,7 @@ class ChunkManager: self.enable_distributed_storage = enable_distributed_storage self.device = init_device or get_current_device() self.chunk_groups: Dict[str, Deque[Chunk]] = {} + self.groups_force_data_on_cuda: Dict[str, bool] = {} self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {} self.accessed_chunks: Set[Chunk] = set() self.lazy_release_tensors: List[torch.Tensor] = [] @@ -292,6 +340,17 @@ class ChunkManager: self.rank_load: Dict[str, torch.Tensor] = {} self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} + def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None: + """Create a chunk group. + + Args: + group_name (str): group name + force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False. + """ + assert group_name not in self.chunk_groups + self.chunk_groups[group_name] = deque() + self.groups_force_data_on_cuda[group_name] = force_data_on_cuda + def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None: """ Append a tensor to a chunk. 
@@ -304,19 +363,20 @@ class ChunkManager: if self.chunk_size is not None and tensor.numel() > self.chunk_size: raise ValueError( f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})') - if group_name not in self.chunk_groups: - self.chunk_groups[group_name] = deque() - try: # append the tensor to the last chunk self.chunk_groups[group_name][-1].append(tensor) except (IndexError, ChunkFullError): - # the except statement will be triggered when there is no chunk or + # the except statement will be triggered when there is no chunk or # the last chunk in the chunk group is full # this will create a new chunk and allocate this chunk to its corresponding process chunk_size = self.chunk_size or tensor.numel() src_rank = self._get_next_src_rank(group_name) - chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device) + chunk = Chunk(chunk_size, + src_rank, + tensor.dtype, + self.device, + force_data_on_cuda=self.groups_force_data_on_cuda[group_name]) if self.enable_distributed_storage and self.chunk_size is None: self.rank_load[group_name][src_rank] += chunk_size @@ -387,7 +447,7 @@ class ChunkManager: # update the memory consumption after releasing self.total_mem[chunk.device_type] -= chunk.mem - def move_chunk(self, chunk: Chunk, device: torch.device) -> None: + def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None: """ Move the chunk to the target device. @@ -399,7 +459,7 @@ class ChunkManager: return if chunk.can_move_device and not chunk.is_empty: self.total_mem[chunk.device_type] -= chunk.mem - chunk.move_device(device) + chunk.move_device(device, update_ptr=update_ptr) self.total_mem[chunk.device_type] += chunk.mem def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py index f1d508d83..243c03941 100644 --- a/tests/test_tensor/test_chunk.py +++ b/tests/test_tensor/test_chunk.py @@ -44,6 +44,7 @@ def run_chunk_zero(use_chunk, use_zero): params = [torch.rand(8, 8) for _ in range(3)] chunk_size = 128 if use_chunk else None chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero) + chunk_manager.create_group('param') assert chunk_manager.total_mem['cpu'] == 0 assert chunk_manager.total_mem['cuda'] == 0 for p in params: diff --git a/tests/test_tensor/test_zero.py b/tests/test_tensor/test_zero.py deleted file mode 100644 index 62beff7da..000000000 --- a/tests/test_tensor/test_zero.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest -import colossalai -from colossalai.context.parallel_mode import ParallelMode -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ChunkManager -from colossalai.core import global_context as gpc -from functools import partial -from _utils import tensor_equal, set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.nn.parallel import ColoDDPV2 -from colossalai.testing import parameterize -from colossalai.gemini.gemini_mgr import GeminiManager - - -def check_param_equal(model, torch_model): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - if p.storage().size() > 0: - assert tensor_equal(torch_p, p.float()), 
f'{torch_p} vs {p}' - - -def check_grad_equal(model, torch_model): - for p, torch_p in zip(model.parameters(), torch_model.parameters()): - if p.grad is not None: - assert tensor_equal(torch_p.grad, p.grad.float()) - - -@parameterize('use_chunk', [False, True]) -@parameterize('use_zero', [False, True]) -def run_gpt(use_chunk, use_zero): - set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - with ColoInitContext(device=get_current_device()): - model = model_builder(checkpoint=True) - model = model.cuda() - torch_model = model_builder().cuda() - for torch_p, p in zip(torch_model.parameters(), model.parameters()): - torch_p.data.copy_(p) - model = model.half() - chunk_size = 38 * 1024**2 if use_chunk else None - chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero) - gemini_manager = GeminiManager('cuda', chunk_manager) - model = ColoDDPV2(model, gemini_manager) - torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA)) - print(chunk_manager) - check_param_equal(model, torch_model) - model.train() - torch_model.train() - set_seed(gpc.get_local_rank(ParallelMode.DATA)) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): - logits = model(input_ids, attn_mask) - torch_logits = torch_model(input_ids, attn_mask) - assert tensor_equal(torch_logits, logits.float()) - loss = criterion(logits, input_ids) - torch_loss = criterion(torch_logits, input_ids) - model.backward(loss) - torch_loss.backward() - check_grad_equal(model, torch_model) - break - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_gpt() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) -@rerun_if_address_is_in_use() -def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_gpt(4) diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py index 7df9e110d..a86735815 100644 --- a/tests/test_tensor/test_zero_optim.py +++ b/tests/test_tensor/test_zero_optim.py @@ -25,22 +25,28 @@ def check_param_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): if p.storage().size() > 0: assert p.dtype == torch.half - assert tensor_equal(torch_p, p), f'{torch_p} vs {p}' + assert tensor_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}' -def run_step(model, criterion, optimizer, input_ids, attn_mask): +def check_grad_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + if p.grad is not None: + assert tensor_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): optimizer.zero_grad() logits = model(input_ids, attn_mask) logits = logits.float() loss = criterion(logits, input_ids) optimizer.backward(loss) - optimizer.step() return logits @parameterize('use_chunk', [False, True]) @parameterize('use_zero', [False, True]) -def run_gpt(use_chunk, use_zero): +@parameterize('placement_policy', ['cuda', 'cpu']) +def run_gpt(use_chunk, use_zero, placement_policy): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable('gpt2') 
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -52,9 +58,11 @@ def run_gpt(use_chunk, use_zero): for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p) - chunk_size = 38 * 1024**2 if use_chunk else None - chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero) - gemini_manager = GeminiManager('cuda', chunk_manager) + chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None + chunk_manager = ChunkManager(chunk_size, + enable_distributed_storage=use_zero, + init_device=GeminiManager.get_default_device(placement_policy)) + gemini_manager = GeminiManager(placement_policy, chunk_manager) model = ColoDDPV2(model, gemini_manager) optim = HybridAdam(model.parameters(), lr=1e-3) optim = ZeroOptimizer(optim, model, initial_scale=32) @@ -64,7 +72,7 @@ def run_gpt(use_chunk, use_zero): torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA)) - # print(chunk_manager) + print(chunk_manager) check_param_equal(model, torch_model) model.train() torch_model.train() @@ -72,9 +80,12 @@ def run_gpt(use_chunk, use_zero): for i, (input_ids, attn_mask) in enumerate(train_dataloader): if i > 2: break - logits = run_step(model, criterion, optim, input_ids, attn_mask) - torch_logits = run_step(torch_model, criterion, torch_optim, input_ids, attn_mask) + logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) assert tensor_equal(logits, torch_logits) + check_grad_equal(model, torch_model) + optim.step() + torch_optim.step() check_param_equal(model, torch_model)
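
Usage note (illustrative, not part of the diff): the sketch below shows how the pieces introduced above fit together, namely the "cpu" placement policy, GeminiManager.get_default_device() and the new ChunkManager.create_group() step. It mirrors the updated tests; the toy Linear model, the fixed chunk size and the single-process launch are assumptions made only for illustration.

    import torch
    import colossalai
    from colossalai.utils import free_port
    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext
    from colossalai.tensor import ChunkManager
    from colossalai.gemini.gemini_mgr import GeminiManager
    from colossalai.nn.parallel import ColoDDPV2

    # Single-process launch for illustration; the tests spawn workers with mp.spawn.
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # Toy fp16 model; ColoDDPV2 expects every parameter to be torch.half.
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(32, 32)
    model = model.half()

    placement_policy = 'cpu'    # now accepted in addition to 'cuda' (and 'auto')
    chunk_size = 1024           # assumed here; test_zero_optim.py derives it via ChunkManager.search_chunk_size(...)
    chunk_manager = ChunkManager(chunk_size,
                                 enable_distributed_storage=True,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)

    # ColoDDPV2 now creates the 'fp16_param' (force_data_on_cuda=True) and 'fp32_param'
    # chunk groups itself; code driving a ChunkManager directly must call
    # create_group() before append_tensor(), as test_chunk.py does with
    # chunk_manager.create_group('param').
    model = ColoDDPV2(model, gemini_manager)

The optimizer path (HybridAdam wrapped in ZeroOptimizer, run_fwd_bwd() followed by optim.step()) is unchanged from test_zero_optim.py above.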