add pipeline memory balance

pull/306/head
mwiacx 2023-09-13 16:55:42 +08:00
parent 42851be36b
commit 0b1c6c6704
5 changed files with 262 additions and 40 deletions

View File

@@ -81,6 +81,10 @@ hybrid_zero_optimizer = dict(
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
# balance cuda memory usage across pipeline stages (uneven activation memory)
cuda_memory_balance=False,
cuda_memory_balance_amount=1 * 1024, # MB
cuda_memory_balance_compensation={},
)
loss = dict(
@@ -140,8 +144,8 @@ pipeline parallel (dict):
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=8,
pipeline=dict(size=1, interleaved_overlap=True),
zero1=-1,
pipeline=dict(size=2, interleaved_overlap=True),
sequence_parallel=False,
)
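For reference, a sketch of a configuration that enables the new option (values are illustrative; compensation keys are pipeline stage indices, negative keys counting from the last stage, and values are assumed to be in MB like cuda_memory_balance_amount):

hybrid_zero_optimizer = dict(
    reduce_bucket_size=512 * 1024 * 1024,
    clip_grad_norm=1.0,
    cuda_memory_balance=True,
    cuda_memory_balance_amount=1 * 1024,  # MB
    cuda_memory_balance_compensation={-1: 512},  # per-stage adjustment in MB, -1 = last stage
)
parallel = dict(
    zero1=-1,
    pipeline=dict(size=2, interleaved_overlap=True),  # pipeline size must be even
    sequence_parallel=False,
)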

View File

@@ -4,11 +4,13 @@
import math
from functools import partial
from itertools import product
from typing import List, Optional
import torch
import torch.distributed as dist
from torch.optim import Optimizer
from internlm.core.communication import recv_obj_meta, send_obj_meta
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.monitor import send_alert_message
@@ -33,12 +35,34 @@ from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from .utils import compute_norm
from .utils import compute_norm, find_subset_with_target_sum
inf = math.inf
logger = get_logger(__file__)
def _find_tensors_with_target_memory(tensors: List[torch.Tensor], target: float) -> Optional[List[int]]:
tensor_mems = [tensor.nelement() * tensor.element_size() for tensor in tensors]
approximate_thresholds = [0.01 * i for i in range(1, 100)]
result = None
for approximate_threshold in approximate_thresholds:
result = find_subset_with_target_sum(tensor_mems, target * 1024 * 1024, approximate_threshold)
if result is not None:
break
return result
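# For illustration (hypothetical sizes): given eight 2 MB tensors and target=6 (MB),
# _find_tensors_with_target_memory returns the indices of a subset totalling roughly 6 MB,
# e.g. [0, 1, 2]. The tolerance starts at 1% of the target and is relaxed in 1% steps;
# None is returned if no subset fits even at 99% tolerance.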
def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
with torch.no_grad():
flat_tensor = flatten(tensors)
flat_tensor = flat_tensor.data.cuda()
sync_param(flat_tensor=flat_tensor, tensor_list=tensors)
return flat_tensor
class BaseOptimizer(Optimizer):
"""
Base Optimizer.
@@ -133,6 +157,47 @@ class HybridZeroOptimizer(BaseOptimizer):
# self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size
# Cuda memory balance
self._enable_memory_balance = zero_cfg.cuda_memory_balance
if self._enable_memory_balance:
assert gpc.get_world_size(ParallelMode.PIPELINE) > 0, "pipeline parallel size must be > 0"
assert gpc.get_world_size(ParallelMode.PIPELINE) % 2 == 0, "pipeline parallel size must be even"
_peer_local_rank = gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
_self_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
self._memory_balance_role = gpc.get_local_rank(ParallelMode.PIPELINE) // (
gpc.get_world_size(ParallelMode.PIPELINE) // 2
) # 0: sender, 1: receiver
self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
self._fp32_flat_proxy_param_of_current_rank = None
self._memory_balance_comm_handle = None
compensation_conf = {
k if k >= 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
for k, v in zero_cfg.cuda_memory_balance_compensation.items()
}
_compensation_amount = compensation_conf.get(_self_local_rank, 0) - compensation_conf.get(
_peer_local_rank, 0
)
if self._memory_balance_role == 1:
_compensation_amount = -_compensation_amount
# We balance the memory load caused by uneven activation footprints across pipeline stages by
# having each stage in the latter half of the pipeline proxy a portion of the optimizer
# parameters of its peer stage in the former half. Typically, the first stage's activations
# occupy pp_size units of memory, decreasing by one unit per stage, so the last stage's
# activations occupy 1 unit. Moving half of the difference between a pair of peer stages
# equalizes them, hence the first division by 2. Since one set of parameters corresponds to
# two sets of optimizer states, only a third of the moved memory is parameters, hence the
# division by 3. Finally, the split parameters are fp16 while the proxied optimizer parameters
# are fp32, so the selection target over fp16 tensors is divided by 2 once more.
self._memory_balance_amount = (
(zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
/ 2
/ 3
/ 2
)
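# For illustration: with pp_size = 4, cuda_memory_balance_amount = 1024 MB and no compensation,
# stages 0 and 1 are senders paired with receivers 3 and 2 respectively, so stage 0 selects
# (1024 * |3 - 0|) / 2 / 3 / 2 = 256 MB of fp16 parameters to proxy and stage 1 selects
# (1024 * |2 - 1|) / 2 / 3 / 2 = 85.3 MB.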
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
@@ -172,6 +237,9 @@ class HybridZeroOptimizer(BaseOptimizer):
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
# Parameters to be proxied are selected only from the first parameter group.
_enable_memory_balance = group_id == 0 and self._enable_memory_balance
group_params = param_group["params"]
# add the fp16 params to fp16_param_groups for bookkeeping
@@ -182,14 +250,28 @@
self.param_group_no_params_ranks.append(no_params_ranks)
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)
# split proxy parameters
if _enable_memory_balance and self._memory_balance_role == 0:
proxy_params_per_rank = [
_find_tensors_with_target_memory(params, self._memory_balance_amount) for params in params_per_rank
]
else:
proxy_params_per_rank = [None for _ in params_per_rank]
# store the mapping between param to rank each param should belong to only one rank
for rank, params in enumerate(params_per_rank):
for rank in range(self._zero_world_size):
params = params_per_rank[rank]
proxy_params = proxy_params_per_rank[rank]
# check whether any rank is not assigned params.
if len(params) != 0:
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
for param in params:
setattr(param, "group_id", group_id)
self._param_store.set_param_to_rank(param, rank)
if len(params) == 0:
continue
for param in params:
setattr(param, "group_id", group_id)
self._param_store.set_param_to_rank(param, rank)
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params, proxy_params)
# move to cpu to make room to create the flat tensor
for param in group_params:
@@ -198,13 +280,28 @@
# flatten the reordered tensors
for rank in range(self._zero_world_size):
# No flat fp16 buffer is allocated if the process has no parameters.
if rank not in self.param_group_no_params_ranks[group_id]:
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
with torch.no_grad():
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.data.cuda()
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
if rank in self.param_group_no_params_ranks[group_id]:
continue
_params = self._param_store.get_fp16_params_by_rank_group(rank, group_id, option="without_proxy")
_flat_tensor = _flatten_and_sync_params(_params)
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, _flat_tensor)
if _enable_memory_balance and self._memory_balance_role == 0:
_params = self._param_store.get_fp16_params_by_rank_group(rank, group_id, option="proxy_only")
_flat_tensor = _flatten_and_sync_params(_params)
self._param_store.add_flat_proxy_param_by_rank_group(rank, group_id, _flat_tensor)
if _enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)
send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
dist.send(flat_proxy_param, dst=self._memory_balance_peer)
elif _enable_memory_balance and self._memory_balance_role == 1:
flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
dist.recv(flat_proxy_param, src=self._memory_balance_peer)
# create a copy of fp32 weights of the parameters for which this rank is responsible
# No flat fp32 buffer is allocated if the process has no parameters.
@@ -221,7 +318,15 @@
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group["params"] = [fp32_flat_current_rank]
optim_params = [fp32_flat_current_rank]
if _enable_memory_balance and self._memory_balance_role == 1:
flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
flat_proxy_param.requires_grad = True
optim_params.append(flat_proxy_param)
self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param
param_group["params"] = optim_params
# set reduction state
for param in self._fp16_param_groups[group_id]:
@@ -438,23 +543,54 @@ class HybridZeroOptimizer(BaseOptimizer):
self._param_store.reset_reduced_data_for_compute_norm()
# accumulate gradient
proxy_gradients = []
avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
# the following operations are performed only on the rank to which parameters are assigned.
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
if self._zero_local_rank in self.param_group_no_params_ranks[group_id]:
continue
if group_id not in avg_gradients:
avg_gradients[group_id] = []
param_group = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="without_proxy"
)
param_idx = 0
if group_id not in avg_gradients:
avg_gradients[group_id] = []
param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients[group_id]) == param_idx:
avg_gradients[group_id].append(param.grad)
else:
avg_gradients[group_id][param_idx].add_(param.grad)
param_idx += 1
if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
param_group = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="proxy_only"
)
for param in param_group:
if param.grad is not None:
if len(avg_gradients[group_id]) == param_idx:
avg_gradients[group_id].append(param.grad)
else:
avg_gradients[group_id][param_idx].add_(param.grad)
param_idx += 1
assert param.grad is not None, "gradient of proxy parameter is None"
proxy_gradients.append(param.grad)
# send the offloaded gradients to the receiver
if self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_grads = flatten(proxy_gradients)
self._memory_balance_comm_handle = dist.isend(flat_proxy_grads, self._memory_balance_peer)
# torch.cuda.synchronize()
elif self._enable_memory_balance and self._memory_balance_role == 1:
_shape = self._fp32_flat_proxy_param_of_current_rank.shape
_device = self._fp32_flat_proxy_param_of_current_rank.device
flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
self._memory_balance_comm_handle = dist.irecv(flat_proxy_gradient, self._memory_balance_peer)
# torch.cuda.synchronize()
self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
)
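# For illustration: with two pipeline stages, stage 0 (role 0) posts dist.isend of the flattened
# proxy gradients while stage 1 (role 1) posts the matching dist.irecv; both keep the returned
# handle and wait on it in step() right before optim.step(), so the transfer can overlap with the
# remaining gradient bookkeeping.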
# the gradients needed are stored in the avg_gradients buffer
# thus, can clear this
@@ -633,9 +769,27 @@ class HybridZeroOptimizer(BaseOptimizer):
# For those ranks that are not assigned parameters, we just wait for other ranks
# to broadcast the parameters they have updated.
if self.has_params:
if self._enable_memory_balance:
self._memory_balance_comm_handle.wait()
self.optim.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
if self._enable_memory_balance and self._memory_balance_role == 1:
self._fp32_flat_proxy_param_of_current_rank.grad = None
# exchange the updated proxy params with the peer pipeline stage
if self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
rank=self._zero_local_rank, group_id=0
)
dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
# torch.cuda.synchronize()
elif self._enable_memory_balance and self._memory_balance_role == 1:
flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
# torch.cuda.synchronize()
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
if self.param_group_has_params[group_id]:
@@ -677,6 +831,17 @@ class HybridZeroOptimizer(BaseOptimizer):
else:
handles.append(handle)
if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(rank, group_id)
handle = dist.broadcast(
flat_proxy_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
)
if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
for handle in handles:
handle.wait()

View File

@@ -146,7 +146,9 @@ class ParameterStore(BaseStore):
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
self._rank_group_id_to_flat_fp16_param = dict()
self._rank_groupid_to_flat_fp16_param = dict()
self._rank_groupid_to_proxy_param_indexs = dict()
self._rank_groupid_to_flat_proxy_param = dict()
# param reduction data structures
self._is_param_reduced = dict()
@@ -192,26 +194,55 @@
tensor_rank = self._fp16_param_to_rank[tensor]
return tensor_rank == self._local_rank
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list, indexs_to_proxy=None) -> None:
if rank not in self._rank_groupid_to_fp16_param_list:
self._rank_groupid_to_fp16_param_list[rank] = dict()
self._rank_groupid_to_proxy_param_indexs[rank] = dict()
if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
self._rank_groupid_to_proxy_param_indexs[rank][group_id] = []
if indexs_to_proxy is not None:
self._rank_groupid_to_proxy_param_indexs[rank][group_id].extend(indexs_to_proxy)
self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)
def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
return self._rank_groupid_to_fp16_param_list[rank][group_id]
def get_fp16_params_by_rank_group(self, rank, group_id, option: str = "all") -> List[Tensor]:
res = []
if option == "without_proxy":
for idx, param in enumerate(self._rank_groupid_to_fp16_param_list[rank][group_id]):
if idx in self._rank_groupid_to_proxy_param_indexs[rank][group_id]:
continue
res.append(param)
elif option == "proxy_only":
for idx, param in enumerate(self._rank_groupid_to_fp16_param_list[rank][group_id]):
if idx not in self._rank_groupid_to_proxy_param_indexs[rank][group_id]:
continue
res.append(param)
else:
res = self._rank_groupid_to_fp16_param_list[rank][group_id]
return res
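# For illustration (hypothetical usage): option="without_proxy" returns the params whose
# optimizer states stay on this rank, option="proxy_only" returns only the params whose
# optimizer states are held by the peer pipeline stage, and the default "all" returns the
# full list, e.g.
#   local_params = param_store.get_fp16_params_by_rank_group(rank, group_id, option="without_proxy")
#   proxy_params = param_store.get_fp16_params_by_rank_group(rank, group_id, option="proxy_only")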
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_group_id_to_flat_fp16_param:
self._rank_group_id_to_flat_fp16_param[rank] = dict()
if rank not in self._rank_groupid_to_flat_fp16_param:
self._rank_groupid_to_flat_fp16_param[rank] = dict()
self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
self._rank_groupid_to_flat_fp16_param[rank][group_id] = tensor
def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
return self._rank_groupid_to_flat_fp16_param[rank][group_id]
def add_flat_proxy_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_groupid_to_flat_proxy_param:
self._rank_groupid_to_flat_proxy_param[rank] = dict()
self._rank_groupid_to_flat_proxy_param[rank][group_id] = tensor
def get_flat_proxy_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_groupid_to_flat_proxy_param[rank][group_id]
def is_param_reduced(self, tensor):
return self._is_param_reduced[tensor]

View File

@@ -5,7 +5,7 @@ import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
@@ -317,6 +317,27 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
return total_norm
def find_subset_with_target_sum(nums: List[int], target: float, approximate_threshold: float = 0.0) -> Optional[List[int]]:
indexs = []
def _inner_helper(start: int, tmpTarget: int, part_idxs: List[int]):
if len(indexs) > 0:
return
if len(part_idxs) > 0 and (
tmpTarget >= -target * approximate_threshold and tmpTarget <= target * approximate_threshold
):
indexs.append(part_idxs)
elif tmpTarget > 0:
for i in range(start, len(nums)):
num = nums[i]
_inner_helper(i + 1, tmpTarget - num, part_idxs + [i])
_inner_helper(start=0, tmpTarget=target, part_idxs=[])
return indexs[0] if len(indexs) > 0 else None
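# Example of find_subset_with_target_sum (illustrative): find_subset_with_target_sum([4, 8, 2, 6], target=10)
# returns [0, 3] because nums[0] + nums[3] == 10 exactly; with approximate_threshold=0.2 the first
# subset found within +/- 2 of the target is accepted instead, here [0, 1] (4 + 8 == 12).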
class BaseGradScaler(ABC):
"""A base class for the gradient scaler.

View File

@@ -177,9 +177,10 @@ def main(args):
memory_profiler = SimpleMemoryProfiler(
model,
optimizer.optim,
log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
log_folder=f"memory_trace/rank{gpc.get_global_rank()}"
+ f"_dp{gpc.get_local_rank(ParallelMode.DATA)}"
+ f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
+ f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
)
else:
memory_profiler = None