diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index b36b88455..0879f5fd2 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -55,7 +55,6 @@ class MoeContext(metaclass=SingletonMeta):
         return self.has_setup
 
     def setup(self, seed: int, use_kernel_optim: bool = True):
-        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
         _check_sanity()
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
 
@@ -93,8 +92,8 @@ class MoeContext(metaclass=SingletonMeta):
         gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
         lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
-        assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number"\
-            " is not a multiple of ep size or vice versa."
+        assert gt_flag or lt_flag, "Automatic experts placement does not support expert number" \
+            " is not a multiple of ep size or vice versa."
 
         # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
         # there are multiple experts in each GPU and each GPU has different experts
diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py
index e02d14055..648d48ec5 100644
--- a/colossalai/gemini/chunk/chunk.py
+++ b/colossalai/gemini/chunk/chunk.py
@@ -51,6 +51,8 @@ def alloc_storage(tensor: torch.Tensor) -> None:
 
 class Chunk:
+    _total_number = 0
+
     def __init__(self,
                  chunk_size: int,
                  process_group: ColoProcessGroup,
@@ -73,6 +75,8 @@ class Chunk:
             keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
             pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
         """
+        self.count_id = Chunk._total_number
+        Chunk._total_number += 1
         self.chunk_size = chunk_size
         self.utilized_size = 0
diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py
index 2d75dcce5..4a2474a63 100644
--- a/colossalai/gemini/chunk/manager.py
+++ b/colossalai/gemini/chunk/manager.py
@@ -28,7 +28,7 @@ class ChunkManager:
         self.chunk_groups: Dict[str, Deque] = dict()
         self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
         self.accessed_chunks: Set[Chunk] = set()
-        self.lazy_release_tensors: List[torch.Tensor] = list()
+        self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
     def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
@@ -91,9 +90,8 @@ class ChunkManager:
         self.__sub_memroy_usage(chunk.memory_usage)
         if chunk.device_type == 'cpu':
             chunk.shard_move(get_current_device())
-        chunk.access_chunk()
+        self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
-        self.accessed_chunks.add(chunk)
 
     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA.
@@ -102,9 +101,8 @@ class ChunkManager:
             return
         if chunk.can_release:
             self.__sub_memroy_usage(chunk.memory_usage)
-            chunk.release_chunk()
+            self.__sub_accessed_chunk(chunk)
             self.__add_memory_usage(chunk.memory_usage)
-            self.accessed_chunks.remove(chunk)
 
     def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
         """Move the shard of the chunk to the target device.
@@ -128,6 +126,7 @@ class ChunkManager:
             return False
         self.__sub_memroy_usage(chunk.memory_usage)
         chunk.reduce()
+        self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
 
@@ -151,31 +150,15 @@ class ChunkManager:
         """
         return self.tensor_chunk_map[tensor]
 
-    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
+    def get_cuda_movable_chunks(self) -> List[Chunk]:
         """
-        Add tensors to the buffer for lazy release.
-
-        Args:
-            tensors (List[torch.Tensor]): the tensors to be released lazily
+        Get all chunks that can be moved.
         """
-        self.lazy_release_tensors.extend(tensors)
-
-    def exec_lazy_release(self) -> None:
-        """
-        Execute release for tensors added to the lazy release buffer.
-        """
-
-        for chunk in self.get_chunks(self.lazy_release_tensors):
-            self.release_chunk(chunk)
-        self.lazy_release_tensors.clear()
-
-    def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]:
         chunk_list = []
-        for group_name in self.chunk_groups:
-            if group_type in group_name:
-                for chunk in self.chunk_groups[group_name]:
-                    if chunk.device_type == 'cuda' and chunk.can_move:
-                        chunk_list.append(chunk)
+        for chunk in self.accessed_chunks:
+            if chunk.can_release:
+                chunk_list.append(chunk)
+        chunk_list.sort(key=lambda x: x.count_id)
         return chunk_list
 
     def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
@@ -235,3 +218,13 @@ class ChunkManager:
     def __add_memory_usage(self, usage: Dict[str, int]):
         for k, v in usage.items():
             self.total_mem[k] += v
+
+    def __add_accessed_chunk(self, chunk: Chunk):
+        chunk.access_chunk()
+        self.accessed_chunks.add(chunk)
+        self.accessed_mem += chunk.chunk_mem
+
+    def __sub_accessed_chunk(self, chunk: Chunk):
+        chunk.release_chunk()
+        self.accessed_chunks.remove(chunk)
+        self.accessed_mem -= chunk.chunk_mem
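
Note on the ChunkManager changes above: the lazy-release buffer is gone, and the manager now keeps a running accessed_mem counter that grows when a chunk is gathered (__add_accessed_chunk) and shrinks when it is released or reduced (__sub_accessed_chunk). The self-contained sketch below uses hypothetical _Chunk/_Manager stand-ins, not the real ColossalAI classes, to illustrate the bookkeeping invariant that ZeroDDP._post_backward asserts further down: once every gathered chunk has been released, accessed_mem returns to zero.

# Minimal, self-contained sketch of the accessed-memory bookkeeping pattern.
# `_Chunk` and `_Manager` are illustrative stand-ins, not the ColossalAI classes.
class _Chunk:

    def __init__(self, chunk_mem: int) -> None:
        self.chunk_mem = chunk_mem      # full (gathered) payload size in bytes
        self.gathered = False

    def access_chunk(self) -> None:     # gather: chunk becomes usable for compute
        self.gathered = True

    def release_chunk(self) -> None:    # scatter: only a shard stays resident
        self.gathered = False


class _Manager:

    def __init__(self) -> None:
        self.accessed_chunks = set()
        self.accessed_mem = 0           # bytes currently held by gathered chunks

    def access_chunk(self, chunk: _Chunk) -> None:
        if chunk in self.accessed_chunks:
            return
        chunk.access_chunk()
        self.accessed_chunks.add(chunk)
        self.accessed_mem += chunk.chunk_mem

    def release_chunk(self, chunk: _Chunk) -> None:
        if chunk not in self.accessed_chunks:
            return
        chunk.release_chunk()
        self.accessed_chunks.remove(chunk)
        self.accessed_mem -= chunk.chunk_mem


mgr = _Manager()
chunks = [_Chunk(4 * 1024**2), _Chunk(8 * 1024**2)]
for c in chunks:
    mgr.access_chunk(c)
assert mgr.accessed_mem == 12 * 1024**2
for c in chunks:
    mgr.release_chunk(c)
assert mgr.accessed_mem == 0    # the invariant ZeroDDP checks after backward
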
diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 0bdddd9a7..6d6b7425c 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -56,14 +56,14 @@ class GeminiManager:
         self._evict_time = 0
         self._comp_cuda_demand_time = 0
 
-    def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
         start = time()
         self._record_chunks_order(chunks)
-        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type)
+        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
         self._layout_time += time() - start
 
         vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
@@ -78,7 +78,7 @@ class GeminiManager:
         self._h2d_volume += cuda_demand
 
     @functools.lru_cache(maxsize=None)
-    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str):
+    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
         start = time()
         cuda_demand = 0
         for chunk in chunks:
@@ -93,7 +93,7 @@ class GeminiManager:
             raise RuntimeError
         self._comp_cuda_demand_time += time() - start
 
-        can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type)
+        can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks
 
     def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
""" # find stateful tensor in state COMPUTE start = time() self._record_chunks_order(chunks) - cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type) + cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, @@ -78,7 +78,7 @@ class GeminiManager: self._h2d_volume += cuda_demand @functools.lru_cache(maxsize=None) - def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str): + def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]): start = time() cuda_demand = 0 for chunk in chunks: @@ -93,7 +93,7 @@ class GeminiManager: raise RuntimeError self._comp_cuda_demand_time += time() - start - can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type) + can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks() return cuda_demand, can_evict_chunks def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index 1a7e172ed..ab1988b11 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -36,8 +36,9 @@ class CPUPlacementPolicy(PlacementPolicy): volume = 0 start = time() for chunk in can_evict_chunks: + self.chunk_manager.release_chunk(chunk) self.chunk_manager.move_chunk(chunk, torch.device('cpu')) - volume += chunk.shard_mem + volume += chunk.chunk_mem return volume, time() - start @@ -116,8 +117,9 @@ class AutoPlacementPolicy(PlacementPolicy): if freed_cuda_model_data >= to_free_cuda_model_data: break + self.chunk_manager.release_chunk(chunk) self.chunk_manager.move_chunk(chunk, torch.device('cpu')) - freed_cuda_model_data += chunk.shard_mem + freed_cuda_model_data += chunk.chunk_mem if freed_cuda_model_data < to_free_cuda_model_data: raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") @@ -147,11 +149,74 @@ class AutoPlacementPolicy(PlacementPolicy): AutoPlacementPolicy._steady_cuda_cap_ratio = ratio +class ConstPlacementPolicy(PlacementPolicy): + + need_mem_stats: bool = False + _accessed_memory_boundary = 512 * 1024**2 + + def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: + super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs) -> Tuple[int, float]: + """ + See the docstrings in the class `AutoPlacementPolicy`. 
+ """ + start = time() + used_accessed_memory = self.chunk_manager.accessed_mem + avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory + freed_accessed_memory = 0 + + if avail_accessed_memory < cuda_demand: + to_free_memory = cuda_demand - avail_accessed_memory + to_free_chunks = can_evict_chunks + + if not warmup: + # sort all chunks + to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list)) + + for chunk in to_free_chunks: + if freed_accessed_memory >= to_free_memory: + break + + self.chunk_manager.release_chunk(chunk) + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + freed_accessed_memory += chunk.chunk_mem + + if freed_accessed_memory < to_free_memory: + raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_memory}, freed {freed_accessed_memory}") + return freed_accessed_memory, time() - start + + @staticmethod + @functools.lru_cache(maxsize=None) + def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: + next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} + for i in range(len(compute_list) - 1, compute_idx, -1): + for chunk in compute_list[i]: + if chunk in next_compute_idx: + next_compute_idx[chunk] = i + next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) + return [t for (t, idx) in next_compute_idx] + + @staticmethod + def set_const_memory_boundary(cuda_memory_mb: int) -> None: + boundary = int(cuda_memory_mb * 1024**2) + assert boundary > 0 + ConstPlacementPolicy._accessed_memory_boundary = boundary + + class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { 'cpu': CPUPlacementPolicy, 'cuda': CUDAPlacementPolicy, - 'auto': AutoPlacementPolicy + 'auto': AutoPlacementPolicy, + 'const': ConstPlacementPolicy } @staticmethod diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index daa4cb15e..5bce81708 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -224,14 +224,16 @@ class ZeroDDP(ColoDDP): # TODO: get param order and filter unused params for p in module.parameters(): assert isinstance(p, ColoParameter) + if getattr(p, '_ddp_to_ignore', False): - p.data = p.half() + p.data = p.data.half() continue - dp_world_size = p.process_group.dp_world_size() - fp32_data = p.float().data - p.data = p.half() + fp32_data = p.data.float() fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + p.data = p.data.half() + + dp_world_size = p.process_group.dp_world_size() self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) self.fp32_params.append(fp32_p) @@ -253,7 +255,6 @@ class ZeroDDP(ColoDDP): self.gemini_manager.pre_iter() with ParamOpHookManager.use_hooks(self.param_op_hook): outputs = self.module(*args, **kwargs) - self.chunk_manager.exec_lazy_release() if self.force_outputs_fp32: return _cast_float(outputs, torch.float) return outputs @@ -265,7 +266,7 @@ class ZeroDDP(ColoDDP): p.grad = None def _post_backward(self): - self.chunk_manager.exec_lazy_release() + assert self.chunk_manager.accessed_mem == 0 self._setup_grads_ptr() self._logger.debug( f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: 
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index a54740221..3824d27f6 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -33,7 +33,10 @@ def ColoModulize(module):
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
+    def __init__(self,
+                 lazy_memory_allocate: bool = False,
+                 device: torch.device = torch.device('cpu'),
+                 dtype: torch.dtype = torch.float):
         """
         Args:
             lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
@@ -42,6 +45,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
+        self._dtype = dtype
 
         self._register_colo_modules()
@@ -87,7 +91,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             # detaching tensor is necessary for optimizers.
             requires_grad = param.requires_grad
             # TODO(jiaruifang) we initialize a Default PG memory
-            colo_param = ColoParameter(param.to(self._device), requires_grad=requires_grad)
+            colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
+                                       requires_grad=requires_grad)
             # add mapping record
             replaced_tensors[param] = colo_param
             delattr(submodule, param_name)
diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py
index 3f3472f0e..584a0fe37 100644
--- a/colossalai/zero/utils/zero_hook_v2.py
+++ b/colossalai/zero/utils/zero_hook_v2.py
@@ -26,9 +26,8 @@ class ZeROHookV2(ParamOpHook):
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
-        self._chunk_manager.exec_lazy_release()
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks, 'fp16_param')
+        self._gemini_manager.adjust_layout(chunks)
         for chunk in chunks:
             self._chunk_manager.access_chunk(chunk)
         self._gemini_manager.sample_model_data()
@@ -38,7 +37,6 @@ class ZeROHookV2(ParamOpHook):
         for p in params:
             tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
             self._chunk_manager.trans_tensor_state(p, tensor_state)
-        self._chunk_manager.add_lazy_release_tensors(params)
 
     def pre_forward(self, params: List[torch.Tensor]) -> None:
         self.pre_op(params)
diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py
index 6bd25c0be..4b9694c0d 100644
--- a/tests/test_gemini/update/test_fwd_bwd.py
+++ b/tests/test_gemini/update/test_fwd_bwd.py
@@ -11,18 +11,14 @@ from functools import partial
 from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
 from colossalai.nn.parallel import ZeroDDP
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
+from colossalai.tensor import ProcessGroup
 from tests.test_tensor.common_utils import debug_print
-from time import time
-from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
-
 
 def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
     chunk_manager = model.chunk_manager
@@ -44,7 +40,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
     return logits
 
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 def exam_gpt_fwd_bwd(placement_policy):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py
index cefda045d..3c82258a5 100644
--- a/tests/test_gemini/update/test_optim.py
+++ b/tests/test_gemini/update/test_optim.py
@@ -48,7 +48,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
     return logits
 
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 def exam_gpt_fwd_bwd(placement_policy):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py
index 86e39097e..69f46b900 100644
--- a/tests/test_gemini/update/test_zeroddp_state_dict.py
+++ b/tests/test_gemini/update/test_zeroddp_state_dict.py
@@ -12,7 +12,6 @@ from functools import partial
 from tests.test_tensor.common_utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.nn.parallel import ZeroDDP
-from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.gemini.gemini_mgr import GeminiManager
 from tests.test_tensor.common_utils import debug_print
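
Putting the pieces together, the intended user-side flow mirrors the updated tests: build the model in the desired dtype under ColoInitContext, cap the gathered-chunk budget for the 'const' policy, and wrap the model with GeminiManager/ZeroDDP. The sketch below is assembled from those tests rather than a documented API; constructor signatures (ChunkManager, GeminiManager, ZeroDDP, search_chunk_configuration) may differ between releases and should be checked against the repository.

import torch
import torch.nn as nn

from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.placement_policy import ConstPlacementPolicy
from colossalai.nn.parallel import ZeroDDP
from colossalai.utils.model.colo_init_context import ColoInitContext

# Build the model directly in fp16 thanks to the new `dtype` argument.
with ColoInitContext(device=torch.device('cpu'), dtype=torch.half):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

# Cap the memory held by gathered chunks at 32 MB for the 'const' policy.
ConstPlacementPolicy.set_const_memory_boundary(32)

# Chunk/Gemini wiring as in tests/test_gemini/update/*; keyword names are taken
# from those tests and may differ slightly between releases.
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager('const', chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)

logits = model(torch.randn(8, 1024, device='cuda', dtype=torch.half))
loss = logits.float().sum()
model.backward(loss)
assert chunk_manager.accessed_mem == 0    # every gathered chunk was released after backward
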