mirror of https://github.com/InternLM/InternLM
feat(hybrid_zero_optim): reduce optimizer peek memory
parent
10f01c4e08
commit
fda0a96def
|
@ -65,6 +65,18 @@ def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
|
|||
return flat_tensor
|
||||
|
||||
|
||||
def _create_fp32_param_copy(parameters: List[torch.Tensor], device: torch.device) -> List[torch.Tensor]:
|
||||
fp32_params = []
|
||||
|
||||
# create fp32 parameter copy
|
||||
for param in parameters:
|
||||
fp32_param = param.data.to(device, dtype=torch.float32)
|
||||
fp32_param.requires_grad = True
|
||||
fp32_params.append(fp32_param)
|
||||
|
||||
return fp32_params
|
||||
|
||||
|
||||
class BaseOptimizer(Optimizer):
|
||||
"""
|
||||
Base Optimizer.
|
||||
|
@ -140,6 +152,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
super().__init__(optim=optimizer)
|
||||
|
||||
self._dtype = self.optim.param_groups[0]["params"][0].dtype
|
||||
self._dtype_memory = self.optim.param_groups[0]["params"][0].element_size()
|
||||
self._cpu_offload = cpu_offload
|
||||
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
|
||||
self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
|
||||
|
@ -154,6 +167,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
# fp16 and fp32 params for mixed precision training
|
||||
self._fp16_param_groups = dict()
|
||||
self._fp32_orig_param_groups_of_current_rank = dict()
|
||||
self._fp32_flat_param_groups_of_current_rank = dict()
|
||||
|
||||
# communication params
|
||||
|
@ -175,6 +189,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
) # 0: sender, 1: receiver
|
||||
self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
|
||||
self._fp32_flat_proxy_param_of_current_rank = None
|
||||
self._fp32_proxy_param_groups_of_current_rank = None
|
||||
|
||||
self._proxy_param_gradients_of_current_rank = None
|
||||
|
||||
compensation_conf = {
|
||||
k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
|
||||
|
@ -196,10 +213,11 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# divided by 2.
|
||||
self._memory_balance_amount = (
|
||||
(zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
|
||||
/ 2
|
||||
/ 3
|
||||
/ 2
|
||||
/ 2 # total -> need to move
|
||||
/ 3 # optim param, exp_avg, exp_avg_sq -> optim param
|
||||
)
|
||||
# convert optimizer parameter dtype to model parameter dtype.
|
||||
self._memory_balance_amount /= 4 / self._dtype_memory
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(
|
||||
|
@ -294,39 +312,53 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
if _enable_memory_balance and self._memory_balance_role == 0:
|
||||
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)
|
||||
proxy_params = self._param_store.get_fp16_params_by_rank_group(
|
||||
self._zero_local_rank, group_id, option="proxy_only"
|
||||
)
|
||||
|
||||
send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
|
||||
send_obj_meta(proxy_params, next_rank=self._memory_balance_peer)
|
||||
dist.send(flat_proxy_param, dst=self._memory_balance_peer)
|
||||
elif _enable_memory_balance and self._memory_balance_role == 1:
|
||||
flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
|
||||
flat_proxy_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
|
||||
proxy_param_shapes = recv_obj_meta(prev_rank=self._memory_balance_peer)
|
||||
|
||||
flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
|
||||
# fix recv_obj_meta result when length of proxy_params is 1.
|
||||
if isinstance(proxy_param_shapes, torch.Size):
|
||||
proxy_param_shapes = [proxy_param_shapes]
|
||||
|
||||
flat_proxy_param = torch.empty(flat_proxy_shape, device=get_current_device(), dtype=self._dtype)
|
||||
dist.recv(flat_proxy_param, src=self._memory_balance_peer)
|
||||
|
||||
# create a copy of fp32 weights of the parameters for which this rank is responsible
|
||||
# No flat fp32 buffer is allocated if the process has no parameters.
|
||||
if self.param_group_has_params[group_id]:
|
||||
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
self._zero_local_rank, group_id
|
||||
)
|
||||
fp32_flat_current_rank = fp16_flat_current_rank.float()
|
||||
device = "cpu" if self._cpu_offload else get_current_device()
|
||||
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
|
||||
fp32_flat_current_rank.requires_grad = True
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
|
||||
|
||||
# need to replace the params in the `params` field in the optimizer
|
||||
# so that when the optimizer calls step(), it only updates the tensors
|
||||
# managed by this data parallel rank
|
||||
optim_params = [fp32_flat_current_rank]
|
||||
# parameters belong to current rank
|
||||
fp16_params = self._param_store.get_fp16_params_by_rank_group(
|
||||
self._zero_local_rank, group_id, option="without_proxy"
|
||||
)
|
||||
fp32_params = _create_fp32_param_copy(fp16_params, device)
|
||||
flat_fp32_param = _flatten_and_sync_params(fp32_params)
|
||||
self._fp32_orig_param_groups_of_current_rank[group_id] = fp32_params
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id] = flat_fp32_param
|
||||
|
||||
# proxy parameters
|
||||
fp32_proxy_params = []
|
||||
if _enable_memory_balance and self._memory_balance_role == 1:
|
||||
flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
|
||||
flat_proxy_param.requires_grad = True
|
||||
optim_params.append(flat_proxy_param)
|
||||
# create empty tensor for fp32 proxy paramters
|
||||
for _shape in proxy_param_shapes:
|
||||
fp32_proxy_param = torch.empty(_shape, dtype=torch.float32, device=device)
|
||||
fp32_proxy_param.requires_grad = True
|
||||
fp32_proxy_params.append(fp32_proxy_param)
|
||||
# sync with received flat fp32 proxy parameter
|
||||
flat_proxy_param = flat_proxy_param.to(device=device, dtype=torch.float32)
|
||||
sync_param(flat_proxy_param, fp32_proxy_params)
|
||||
self._fp32_proxy_param_groups_of_current_rank = fp32_proxy_params
|
||||
|
||||
self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param
|
||||
|
||||
param_group["params"] = optim_params
|
||||
param_group["params"] = fp32_params + fp32_proxy_params
|
||||
|
||||
# set reduction state
|
||||
for param in self._fp16_param_groups[group_id]:
|
||||
|
@ -543,7 +575,6 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._param_store.reset_reduced_data_for_compute_norm()
|
||||
|
||||
# accumulate gradient
|
||||
proxy_gradinets = []
|
||||
avg_gradients = self._grad_store._averaged_gradients
|
||||
|
||||
for group_id in range(self.num_param_groups):
|
||||
|
@ -568,29 +599,13 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
param_idx += 1
|
||||
|
||||
if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
|
||||
self._proxy_param_gradients_of_current_rank = []
|
||||
param_group = self._param_store.get_fp16_params_by_rank_group(
|
||||
self._zero_local_rank, group_id, option="proxy_only"
|
||||
)
|
||||
for param in param_group:
|
||||
assert param.grad is not None, "gradient of proxy parameter is None"
|
||||
proxy_gradinets.append(param.grad)
|
||||
|
||||
# send offload gradients to reciever
|
||||
if self._enable_memory_balance and self._memory_balance_role == 0:
|
||||
flat_proxy_grads = flatten(proxy_gradinets)
|
||||
|
||||
dist.send(flat_proxy_grads, self._memory_balance_peer)
|
||||
# torch.cuda.synchronize()
|
||||
elif self._enable_memory_balance and self._enable_memory_balance == 1:
|
||||
_shape = self._fp32_flat_proxy_param_of_current_rank.shape
|
||||
_device = self._fp32_flat_proxy_param_of_current_rank.device
|
||||
flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
|
||||
|
||||
dist.recv(flat_proxy_gradient, self._memory_balance_peer)
|
||||
# torch.cuda.synchronize()
|
||||
self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
|
||||
dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
|
||||
)
|
||||
self._proxy_param_gradients_of_current_rank.append(param.grad)
|
||||
|
||||
# the gradients needed are stored in the avg_gradients buffer
|
||||
# thus, can clear this
|
||||
|
@ -745,32 +760,25 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
return False, norms
|
||||
|
||||
# copy the grad of fp16 param to fp32 param
|
||||
single_grad_partition_groups = []
|
||||
grads_partition_groups = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
# compute norm
|
||||
# The following operations are performed only on the rank to which parameters are assigned.
|
||||
if not self.param_group_has_params[group_id]:
|
||||
continue
|
||||
|
||||
# create flat gradient for the flat fp32 params
|
||||
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
|
||||
with torch.no_grad():
|
||||
flat_fp16_avg_grads = flatten(gradients)
|
||||
self._grad_store.reset_average_gradients_by_group(group_id)
|
||||
gradients = None # release cuda memory
|
||||
fp32_params = self._fp32_orig_param_groups_of_current_rank[group_id]
|
||||
grads_partition_groups.append([])
|
||||
|
||||
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
|
||||
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
|
||||
flat_fp16_avg_grads = None # release cuda memory
|
||||
for idx, grad in enumerate(gradients):
|
||||
fp32_grad = grad.data.float()
|
||||
fp32_params[idx].grad = fp32_grad
|
||||
grads_partition_groups[group_id].append(fp32_grad)
|
||||
|
||||
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
|
||||
assert (
|
||||
param_shape == flat_fp32_avg_grads.shape
|
||||
), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
|
||||
if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
|
||||
grads_partition_groups[group_id].extend(self._proxy_param_gradients_of_current_rank)
|
||||
|
||||
single_grad_partition_groups.append(flat_fp32_avg_grads)
|
||||
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
|
||||
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
|
||||
gradients = None
|
||||
|
||||
# unscale and clip grads
|
||||
# get the global norm
|
||||
|
@ -781,39 +789,50 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
# the following operations are performed only on the rank to which parameters are assigned.
|
||||
if gpc.config.model.dtype is not torch.float32:
|
||||
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
|
||||
if len(grads_partition_groups) != 0 and self._clip_grad_norm > 0:
|
||||
self._unscale_and_clip_grads(
|
||||
single_grad_partition_groups,
|
||||
grads_partition_groups,
|
||||
list(global_norm_groups.values()),
|
||||
loss_scale,
|
||||
)
|
||||
grads_partition_groups = None
|
||||
|
||||
# update the parameters
|
||||
timer("step").start()
|
||||
|
||||
# send and receive proxy gradients
|
||||
if self._enable_memory_balance and self._memory_balance_role == 0:
|
||||
for gradient in self._proxy_param_gradients_of_current_rank:
|
||||
dist.send(gradient, dst=self._memory_balance_peer)
|
||||
self._proxy_param_gradients_of_current_rank = None
|
||||
elif self._enable_memory_balance and self._enable_memory_balance == 1:
|
||||
for proxy_param in self._fp32_proxy_param_groups_of_current_rank:
|
||||
proxy_gradient = torch.empty(proxy_param.shape, device=proxy_param.device, dtype=self._dtype)
|
||||
dist.recv(proxy_gradient, self._memory_balance_peer)
|
||||
proxy_param.grad = proxy_gradient.to(dtype=proxy_param.dtype)
|
||||
|
||||
# For those ranks that are not assigned parameters, we just wait for other ranks
|
||||
# to send them updated their own parameters.
|
||||
if self.has_params:
|
||||
self.optim.step()
|
||||
# release the fp32 grad
|
||||
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
|
||||
if self._enable_memory_balance and self._memory_balance_role == 1:
|
||||
self._fp32_flat_proxy_param_of_current_rank.grad = None
|
||||
for group_id in range(self.num_param_groups):
|
||||
release_param_grad(self._fp32_orig_param_groups_of_current_rank[group_id])
|
||||
|
||||
if self._enable_memory_balance and self._memory_balance_role == 1:
|
||||
release_param_grad(self._fp32_proxy_param_groups_of_current_rank)
|
||||
|
||||
# receive proxy params
|
||||
if self._enable_memory_balance and self._memory_balance_role == 0:
|
||||
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
|
||||
rank=self._zero_local_rank, group_id=0
|
||||
)
|
||||
dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
|
||||
# torch.cuda.synchronize()
|
||||
dist.recv(flat_proxy_param, self._memory_balance_peer)
|
||||
elif self._enable_memory_balance and self._memory_balance_role == 1:
|
||||
flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
|
||||
dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
|
||||
# torch.cuda.synchronize()
|
||||
dist.send(flat_proxy_param, self._memory_balance_peer)
|
||||
|
||||
# update fp16 partition updated by the current rank
|
||||
for group_id in range(len(self._fp16_param_groups)):
|
||||
for group_id in range(self.num_param_groups):
|
||||
if self.param_group_has_params[group_id]:
|
||||
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=self._zero_local_rank, group_id=group_id
|
||||
|
@ -894,7 +913,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
|
||||
def _unscale_and_clip_grads(self, gradients_groups, total_norm_groups, loss_scale):
|
||||
# compute combined scale factor for this group
|
||||
combined_scale_groups = []
|
||||
|
||||
|
@ -906,7 +925,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if clip > 1.0:
|
||||
combined_scale_groups[group_id] = clip * loss_scale
|
||||
|
||||
for group_id, grad in enumerate(grad_groups_flat):
|
||||
for group_id, grads in enumerate(gradients_groups):
|
||||
for grad in grads:
|
||||
grad.data.mul_(1.0 / combined_scale_groups[group_id])
|
||||
|
||||
def clip_grad_norm(self, model, max_norm):
|
||||
|
|
Loading…
Reference in New Issue