diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 1f1993f..9a8cc5c 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -81,6 +81,10 @@ hybrid_zero_optimizer = dict(
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
     clip_grad_norm=1.0,
+    # cuda memory balance for activation
+    cuda_memory_balance=False,
+    cuda_memory_balance_amount=1 * 1024,  # MB
+    cuda_memory_balance_compensation={},
 )

 loss = dict(
@@ -140,8 +144,8 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """

 parallel = dict(
-    zero1=8,
-    pipeline=dict(size=1, interleaved_overlap=True),
+    zero1=-1,
+    pipeline=dict(size=2, interleaved_overlap=True),
     sequence_parallel=False,
 )
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 8bdeccf..568a406 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -4,11 +4,13 @@
 import math
 from functools import partial
 from itertools import product
+from typing import List

 import torch
 import torch.distributed as dist
 from torch.optim import Optimizer

+from internlm.core.communication import recv_obj_meta, send_obj_meta
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.monitor import send_alert_message
@@ -33,12 +35,34 @@ from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer

-from .utils import compute_norm
+from .utils import compute_norm, find_subset_with_target_sum

 inf = math.inf
 logger = get_logger(__file__)


+def _find_tensors_with_target_memory(tensors: List[torch.Tensor], target: int) -> List[int]:
+    tensor_mems = [tensor.nelement() * tensor.element_size() for tensor in tensors]
+    approximate_thresholds = [0.01 * i for i in range(1, 100)]
+    result = None
+
+    for approximate_threshold in approximate_thresholds:
+        result = find_subset_with_target_sum(tensor_mems, target * 1024 * 1024, approximate_threshold)
+        if result is not None:
+            break
+
+    return result
+
+
+def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
+    with torch.no_grad():
+        flat_tensor = flatten(tensors)
+    flat_tensor = flat_tensor.data.cuda()
+    sync_param(flat_tensor=flat_tensor, tensor_list=tensors)
+
+    return flat_tensor
+
+
 class BaseOptimizer(Optimizer):
     """
     Base Optimizer.
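Note on the two helpers added above: `_find_tensors_with_target_memory` turns each tensor into its size in bytes and asks `find_subset_with_target_sum` (added in `internlm/solver/optimizer/utils.py` below) for a subset whose total is close to the target, retrying with a tolerance relaxed from 1% up to 99%. A minimal, self-contained sketch of that selection idea follows; the sizes are made up and `pick_subset` is a simplified stand-in, not the repo's helper.

# Standalone illustration only: sizes are made up, and pick_subset is a simplified
# stand-in for find_subset_with_target_sum, not the repo's helper.
from typing import List, Optional


def pick_subset(sizes: List[int], target: int, tol: float) -> Optional[List[int]]:
    """Return indices whose sizes sum to target within +/- tol * target, or None."""

    def dfs(start: int, remaining: int, chosen: List[int]) -> Optional[List[int]]:
        if chosen and abs(remaining) <= tol * target:
            return chosen
        if remaining <= 0:
            return None
        for i in range(start, len(sizes)):
            found = dfs(i + 1, remaining - sizes[i], chosen + [i])
            if found is not None:
                return found
        return None

    return dfs(0, target, [])


# Pretend these are per-parameter sizes in MB and we want to proxy about 1 GB of them.
sizes_mb = [400, 300, 250, 200, 150, 100]
picked = None
for tol in (0.01 * i for i in range(1, 100)):  # relax the tolerance like the helper does
    picked = pick_subset(sizes_mb, target=1024, tol=tol)
    if picked is not None:
        break
print(picked, sum(sizes_mb[i] for i in picked))  # [0, 1, 2, 5] 1050 -> within 3% of 1024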
@@ -133,6 +157,47 @@ class HybridZeroOptimizer(BaseOptimizer):
         # self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size

+        # CUDA memory balance
+        self._enable_memory_balance = zero_cfg.cuda_memory_balance
+        if self._enable_memory_balance:
+            assert gpc.get_world_size(ParallelMode.PIPELINE) > 0, "pipeline parallel size must be > 0"
+            assert gpc.get_world_size(ParallelMode.PIPELINE) % 2 == 0, "pipeline parallel size must be even"
+
+            _peer_local_rank = gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
+            _self_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+            self._memory_balance_role = gpc.get_local_rank(ParallelMode.PIPELINE) // (
+                gpc.get_world_size(ParallelMode.PIPELINE) // 2
+            )  # 0: sender, 1: receiver
+            self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
+            self._fp32_flat_proxy_param_of_current_rank = None
+            self._memory_balance_comm_handle = None
+
+            compensation_conf = {
+                k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
+                for k, v in zero_cfg.cuda_memory_balance_compensation.items()
+            }
+            _compensation_amount = compensation_conf.get(_self_local_rank, 0) - compensation_conf.get(
+                _peer_local_rank, 0
+            )
+            if self._memory_balance_role == 1:
+                _compensation_amount = -_compensation_amount
+
+            # We balance the memory load caused by different activation quantities on different stages of the
+            # pipeline by having the latter half of the pipeline stages proxy a portion of the optimizer parameters
+            # from the former half. Typically, the first stage's activation occupies pp_size units of memory,
+            # decreasing in increments, and the last stage's activation occupies 1 unit of memory. The number of
+            # parameters to be proxied can be determined based on pp_rank. Since 1 set of optimizer parameters
+            # corresponds to 2 sets of optimizer states, the number of parameters to be proxied needs to be divided
+            # by 3. Additionally, the split parameters are in fp16, while the actual optimizer state parameters are
+            # in fp32, so they need to be divided by 2.
+            self._memory_balance_amount = (
+                (zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
+                / 2
+                / 3
+                / 2
+            )
+
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
             initial_scale=initial_scale,
@@ -172,6 +237,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
+            # We only select the parameters to be proxied in the first parameter group.
+            _enable_memory_balance = group_id == 0 and self._enable_memory_balance
+
             group_params = param_group["params"]

             # add the fp16 params to fp16_param_groups for bookkeeping
@@ -182,14 +250,28 @@ class HybridZeroOptimizer(BaseOptimizer):
             self.param_group_no_params_ranks.append(no_params_ranks)
             self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
+            # split proxy parameters
+            if _enable_memory_balance and self._memory_balance_role == 0:
+                proxy_params_per_rank = [
+                    _find_tensors_with_target_memory(params, self._memory_balance_amount) for params in params_per_rank
+                ]
+            else:
+                proxy_params_per_rank = [None for _ in params_per_rank]
+
             # store the mapping between param to rank
             # each param should belong to only one rank
-            for rank, params in enumerate(params_per_rank):
+            for rank in range(self._zero_world_size):
+                params = params_per_rank[rank]
+                proxy_params = proxy_params_per_rank[rank]
+
                 # check whether any rank is not assigned params.
-                if len(params) != 0:
-                    self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
-                    for param in params:
-                        setattr(param, "group_id", group_id)
-                        self._param_store.set_param_to_rank(param, rank)
+                if len(params) == 0:
+                    continue
+
+                for param in params:
+                    setattr(param, "group_id", group_id)
+                    self._param_store.set_param_to_rank(param, rank)
+
+                self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params, proxy_params)

             # move to cpu to make room to create the flat tensor
             for param in group_params:
@@ -198,13 +280,28 @@ class HybridZeroOptimizer(BaseOptimizer):
             # flatten the reordered tensors
             for rank in range(self._zero_world_size):
                 # No flat fp16 buffer is allocated if the process has no parameters.
-                if rank not in self.param_group_no_params_ranks[group_id]:
-                    tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
-                    with torch.no_grad():
-                        flat_tensor = flatten(tensor_list)
-                    flat_tensor = flat_tensor.data.cuda()
-                    self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
-                    sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
+                if rank in self.param_group_no_params_ranks[group_id]:
+                    continue
+
+                _params = self._param_store.get_fp16_params_by_rank_group(rank, group_id, option="without_proxy")
+                _flat_tensor = _flatten_and_sync_params(_params)
+                self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, _flat_tensor)
+
+                if _enable_memory_balance and self._memory_balance_role == 0:
+                    _params = self._param_store.get_fp16_params_by_rank_group(rank, group_id, option="proxy_only")
+                    _flat_tensor = _flatten_and_sync_params(_params)
+                    self._param_store.add_flat_proxy_param_by_rank_group(rank, group_id, _flat_tensor)
+
+            if _enable_memory_balance and self._memory_balance_role == 0:
+                flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)
+
+                send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
+                dist.send(flat_proxy_param, dst=self._memory_balance_peer)
+            elif _enable_memory_balance and self._memory_balance_role == 1:
+                flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
+
+                flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
+                dist.recv(flat_proxy_param, src=self._memory_balance_peer)

             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
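The send_obj_meta/dist.send and recv_obj_meta/dist.recv pair above performs a one-time handoff of the flattened proxy parameters from each sender stage to its peer: shape metadata first, then the tensor itself, so the receiver can allocate a matching buffer before the blocking receive. Below is a minimal, self-contained sketch of that handshake using plain torch.distributed (gloo backend, CPU, two processes); the port, tensor size, and the length tensor standing in for InternLM's object-meta helpers are illustrative assumptions.

# Minimal standalone sketch (gloo on CPU, 2 processes) of the proxy-parameter handoff:
# the sender transmits the flat tensor's length first, then the tensor itself, so the
# receiver can allocate a buffer of the right size before the blocking receive.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:  # "sender": owns the flat proxy parameter
        flat_proxy = torch.randn(1000)
        dist.send(torch.tensor([flat_proxy.numel()]), dst=1)  # metadata first
        dist.send(flat_proxy, dst=1)                          # then the data
    else:          # "receiver": allocates from the metadata, then receives
        meta = torch.zeros(1, dtype=torch.long)
        dist.recv(meta, src=0)
        flat_proxy = torch.empty(int(meta.item()))
        dist.recv(flat_proxy, src=0)
        print("received", flat_proxy.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)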
@@ -221,7 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # need to replace the params in the `params` field in the optimizer
                 # so that when the optimizer calls step(), it only updates the tensors
                 # managed by this data parallel rank
-                param_group["params"] = [fp32_flat_current_rank]
+                optim_params = [fp32_flat_current_rank]
+
+                if _enable_memory_balance and self._memory_balance_role == 1:
+                    flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
+                    flat_proxy_param.requires_grad = True
+                    optim_params.append(flat_proxy_param)
+                    self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param
+
+                param_group["params"] = optim_params

             # set reduction state
             for param in self._fp16_param_groups[group_id]:
@@ -438,23 +543,54 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_store.reset_reduced_data_for_compute_norm()

         # accumulate gradient
+        proxy_gradients = []
         avg_gradients = self._grad_store._averaged_gradients
+
         for group_id in range(self.num_param_groups):
             # the following operations are performed only on the rank to which parameters are assigned.
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
+            if self._zero_local_rank in self.param_group_no_params_ranks[group_id]:
+                continue

-                if group_id not in avg_gradients:
-                    avg_gradients[group_id] = []
+            param_group = self._param_store.get_fp16_params_by_rank_group(
+                self._zero_local_rank, group_id, option="without_proxy"
+            )

-                param_idx = 0
+            if group_id not in avg_gradients:
+                avg_gradients[group_id] = []
+
+            param_idx = 0
+            for param in param_group:
+                if param.grad is not None:
+                    if len(avg_gradients[group_id]) == param_idx:
+                        avg_gradients[group_id].append(param.grad)
+                    else:
+                        avg_gradients[group_id][param_idx].add_(param.grad)
+                    param_idx += 1
+
+            if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
+                param_group = self._param_store.get_fp16_params_by_rank_group(
+                    self._zero_local_rank, group_id, option="proxy_only"
+                )
                 for param in param_group:
-                    if param.grad is not None:
-                        if len(avg_gradients[group_id]) == param_idx:
-                            avg_gradients[group_id].append(param.grad)
-                        else:
-                            avg_gradients[group_id][param_idx].add_(param.grad)
-                        param_idx += 1
+                    assert param.grad is not None, "gradient of proxy parameter is None"
+                    proxy_gradients.append(param.grad)
+
+        # send offloaded gradients to the receiver
+        if self._enable_memory_balance and self._memory_balance_role == 0:
+            flat_proxy_grads = flatten(proxy_gradients)
+
+            self._memory_balance_comm_handle = dist.isend(flat_proxy_grads, self._memory_balance_peer)
+            # torch.cuda.synchronize()
+        elif self._enable_memory_balance and self._memory_balance_role == 1:
+            _shape = self._fp32_flat_proxy_param_of_current_rank.shape
+            _device = self._fp32_flat_proxy_param_of_current_rank.device
+            flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
+
+            self._memory_balance_comm_handle = dist.irecv(flat_proxy_gradient, self._memory_balance_peer)
+            # torch.cuda.synchronize()
+            self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
+                dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
+            )

         # the gradients needed are stored in the avg_gradients buffer
         # thus, can clear this
@@ -633,9 +769,27 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
+            if self._enable_memory_balance:
+                self._memory_balance_comm_handle.wait()
+
             self.optim.step()
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+            if self._enable_memory_balance and self._memory_balance_role == 1:
+                self._fp32_flat_proxy_param_of_current_rank.grad = None
+
+        # receive proxy params
+        if self._enable_memory_balance and self._memory_balance_role == 0:
+            flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
+                rank=self._zero_local_rank, group_id=0
+            )
+            dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
+            # torch.cuda.synchronize()
+        elif self._enable_memory_balance and self._memory_balance_role == 1:
+            flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
+            dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
+            # torch.cuda.synchronize()
+
         # update fp16 partition updated by the current rank
         for group_id in range(len(self._fp16_param_groups)):
             if self.param_group_has_params[group_id]:
@@ -677,6 +831,17 @@ class HybridZeroOptimizer(BaseOptimizer):
                 else:
                     handles.append(handle)

+                if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
+                    flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(rank, group_id)
+                    handle = dist.broadcast(
+                        flat_proxy_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
+                    )
+
+                    if self._overlap_sync_param:
+                        self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+                    else:
+                        handles.append(handle)
+
         for handle in handles:
             handle.wait()
diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py
index 05a44d2..d1aae25 100644
--- a/internlm/solver/optimizer/store.py
+++ b/internlm/solver/optimizer/store.py
@@ -146,7 +146,9 @@ class ParameterStore(BaseStore):
         # param partitioning data structures
         self._fp16_param_to_rank = dict()
         self._rank_groupid_to_fp16_param_list = dict()
-        self._rank_group_id_to_flat_fp16_param = dict()
+        self._rank_groupid_to_flat_fp16_param = dict()
+        self._rank_groupid_to_proxy_param_indexs = dict()
+        self._rank_groupid_to_flat_proxy_param = dict()

         # param reduction data structures
         self._is_param_reduced = dict()
@@ -192,26 +194,55 @@ class ParameterStore(BaseStore):
         tensor_rank = self._fp16_param_to_rank[tensor]
         return tensor_rank == self._local_rank

-    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
+    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list, indexs_to_proxy) -> None:
         if rank not in self._rank_groupid_to_fp16_param_list:
             self._rank_groupid_to_fp16_param_list[rank] = dict()
+            self._rank_groupid_to_proxy_param_indexs[rank] = dict()

         if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
             self._rank_groupid_to_fp16_param_list[rank][group_id] = []
+            self._rank_groupid_to_proxy_param_indexs[rank][group_id] = []
+
+        if indexs_to_proxy is not None:
+            self._rank_groupid_to_proxy_param_indexs[rank][group_id].extend(indexs_to_proxy)

         self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)

-    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
-        return self._rank_groupid_to_fp16_param_list[rank][group_id]
+    def get_fp16_params_by_rank_group(self, rank, group_id, option: str = "all") -> List[Tensor]:
+        res = []
+
+        if option == "without_proxy":
+            for idx, param in enumerate(self._rank_groupid_to_fp16_param_list[rank][group_id]):
+                if idx in self._rank_groupid_to_proxy_param_indexs[rank][group_id]:
+                    continue
+                res.append(param)
+        elif option == "proxy_only":
+            for idx, param in enumerate(self._rank_groupid_to_fp16_param_list[rank][group_id]):
+                if idx not in self._rank_groupid_to_proxy_param_indexs[rank][group_id]:
+                    continue
+                res.append(param)
+        else:
+            res = self._rank_groupid_to_fp16_param_list[rank][group_id]
+
+        return res

     def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
-        if rank not in self._rank_group_id_to_flat_fp16_param:
-            self._rank_group_id_to_flat_fp16_param[rank] = dict()
+        if rank not in self._rank_groupid_to_flat_fp16_param:
+            self._rank_groupid_to_flat_fp16_param[rank] = dict()

-        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
+        self._rank_groupid_to_flat_fp16_param[rank][group_id] = tensor

     def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
-        return self._rank_group_id_to_flat_fp16_param[rank][group_id]
+        return self._rank_groupid_to_flat_fp16_param[rank][group_id]
+
+    def add_flat_proxy_param_by_rank_group(self, rank, group_id, tensor) -> None:
+        if rank not in self._rank_groupid_to_flat_proxy_param:
+            self._rank_groupid_to_flat_proxy_param[rank] = dict()
+
+        self._rank_groupid_to_flat_proxy_param[rank][group_id] = tensor
+
+    def get_flat_proxy_param_by_rank_group(self, rank, group_id) -> Tensor:
+        return self._rank_groupid_to_flat_proxy_param[rank][group_id]

     def is_param_reduced(self, tensor):
         return self._is_param_reduced[tensor]
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 38e4560..896d279 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -5,7 +5,7 @@ import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
@@ -317,6 +317,27 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
     return total_norm


+def find_subset_with_target_sum(nums: List[int], target: int, approximate_threshold: float = 0.0) -> List[int]:
+    indexs = []
+
+    def _inner_helper(start: int, tmpTarget: int, part_idxs: List[int]):
+        if len(indexs) > 0:
+            return
+
+        if len(part_idxs) > 0 and (
+            tmpTarget >= -target * approximate_threshold and tmpTarget <= target * approximate_threshold
+        ):
+            indexs.append(part_idxs)
+        elif tmpTarget > 0:
+            for i in range(start, len(nums)):
+                num = nums[i]
+                _inner_helper(i + 1, tmpTarget - num, part_idxs + [i])
+
+    _inner_helper(start=0, tmpTarget=target, part_idxs=[])
+
+    return indexs[0] if len(indexs) > 0 else None
+
+
 class BaseGradScaler(ABC):
     """A base class for the gradient scaler.
diff --git a/train.py b/train.py
index de7cc7c..0cd39c9 100644
--- a/train.py
+++ b/train.py
@@ -177,9 +177,10 @@ def main(args):
         memory_profiler = SimpleMemoryProfiler(
             model,
             optimizer.optim,
-            log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
-            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
-            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
+            log_folder=f"memory_trace/rank{gpc.get_global_rank()}"
+            + f"_dp{gpc.get_local_rank(ParallelMode.DATA)}"
+            + f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
+            + f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
         )
     else:
         memory_profiler = None
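For reference, enabling the feature touches only the hybrid_zero_optimizer and parallel configs; the sketch below shows plausible settings rather than recommended values. Per the __init__ logic above, the pipeline parallel size must be even, cuda_memory_balance_amount is roughly the number of MB proxied per unit of pipeline-stage distance, and negative keys in cuda_memory_balance_compensation are normalized by adding the pipeline size (so -1 addresses the last stage).

# Illustrative config sketch only: the amounts and the compensation entry are placeholders.
hybrid_zero_optimizer = dict(
    reduce_bucket_size=512 * 1024 * 1024,
    clip_grad_norm=1.0,
    # cuda memory balance for activation
    cuda_memory_balance=True,  # turn the balancing on
    cuda_memory_balance_amount=1 * 1024,  # MB proxied per unit of pipeline-stage distance
    cuda_memory_balance_compensation={-1: 256},  # optional per-stage adjustment, same MB units; -1 = last stage
)

parallel = dict(
    zero1=-1,
    pipeline=dict(size=2, interleaved_overlap=True),  # memory balance requires an even pipeline size
    sequence_parallel=False,
)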