mirror of https://github.com/InternLM/InternLM
feat(hybrid_zero_optim): reduce optimizer peak memory
parent: 10f01c4e08
commit: fda0a96def
@@ -65,6 +65,18 @@ def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
     return flat_tensor
 
 
+def _create_fp32_param_copy(parameters: List[torch.Tensor], device: torch.device) -> List[torch.Tensor]:
+    fp32_params = []
+
+    # create fp32 parameter copy
+    for param in parameters:
+        fp32_param = param.data.to(device, dtype=torch.float32)
+        fp32_param.requires_grad = True
+        fp32_params.append(fp32_param)
+
+    return fp32_params
+
+
 class BaseOptimizer(Optimizer):
     """
     Base Optimizer.
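
Note: the new `_create_fp32_param_copy` helper allocates one fp32 master tensor per fp16 parameter, instead of casting a single partition-sized flat fp16 buffer into a second partition-sized flat fp32 buffer. A minimal, self-contained sketch of how the helper behaves (the shapes and the usage below are illustrative, not taken from the commit):

    import torch
    from typing import List

    def _create_fp32_param_copy(parameters: List[torch.Tensor], device: torch.device) -> List[torch.Tensor]:
        fp32_params = []
        for param in parameters:
            # each master copy is allocated individually, so no second
            # partition-sized flat buffer is ever live at the same time
            fp32_param = param.data.to(device, dtype=torch.float32)
            fp32_param.requires_grad = True
            fp32_params.append(fp32_param)
        return fp32_params

    # illustrative usage with hypothetical shapes
    fp16_params = [torch.randn(4, 4, dtype=torch.float16) for _ in range(3)]
    masters = _create_fp32_param_copy(fp16_params, torch.device("cpu"))
    assert all(p.dtype == torch.float32 and p.requires_grad for p in masters)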
@@ -140,6 +152,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         super().__init__(optim=optimizer)
 
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
+        self._dtype_memory = self.optim.param_groups[0]["params"][0].element_size()
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
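
Note: `element_size()` returns the number of bytes per element of the optimizer parameter dtype; `self._dtype_memory` feeds the dtype conversion in the balance-amount computation further down. For reference:

    import torch

    # bytes per element for common training dtypes
    assert torch.tensor([], dtype=torch.float32).element_size() == 4
    assert torch.tensor([], dtype=torch.float16).element_size() == 2
    assert torch.tensor([], dtype=torch.bfloat16).element_size() == 2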
@@ -154,6 +167,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
+        self._fp32_orig_param_groups_of_current_rank = dict()
         self._fp32_flat_param_groups_of_current_rank = dict()
 
         # communication params
@@ -175,6 +189,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             ) # 0: sender, 1: receiver
             self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
             self._fp32_flat_proxy_param_of_current_rank = None
+            self._fp32_proxy_param_groups_of_current_rank = None
+
+            self._proxy_param_gradients_of_current_rank = None
 
             compensation_conf = {
                 k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
@@ -196,10 +213,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             # divided by 2.
             self._memory_balance_amount = (
                 (zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
-                / 2
-                / 3
-                / 2
+                / 2  # total -> need to move
+                / 3  # optim param, exp_avg, exp_avg_sq -> optim param
             )
+            # convert optimizer parameter dtype to model parameter dtype.
+            self._memory_balance_amount /= 4 / self._dtype_memory
 
             # gradient scaler
             self.grad_scaler = DynamicGradScaler(
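
Note: following the commit's own inline comments, the revised formula halves the imbalance (each side of the pair covers half of the gap), divides by 3 because an Adam-style optimizer keeps three fp32 tensors per element (parameter, exp_avg, exp_avg_sq) while only the parameter itself is proxied, and then rescales from fp32 bytes to model-dtype bytes via `/= 4 / self._dtype_memory` instead of the old hard-coded `/ 2`. A worked example of the arithmetic, under the assumption of fp16 model parameters and hypothetical numbers:

    # hypothetical numbers, only to illustrate the arithmetic
    cuda_memory_balance_amount = 12.0   # configured imbalance per rank gap
    rank_gap = 1
    compensation = 0.0
    dtype_memory = 2                    # bytes per element for fp16

    amount = cuda_memory_balance_amount * rank_gap + compensation
    amount = amount / 2                 # total gap -> amount one side must move
    amount = amount / 3                 # (param, exp_avg, exp_avg_sq) -> param only
    amount /= 4 / dtype_memory          # fp32 optimizer bytes -> model-dtype bytes
    print(amount)                       # 12 / 2 / 3 * (2 / 4) = 1.0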
@@ -294,39 +312,53 @@ class HybridZeroOptimizer(BaseOptimizer):
 
             if _enable_memory_balance and self._memory_balance_role == 0:
                 flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)
+                proxy_params = self._param_store.get_fp16_params_by_rank_group(
+                    self._zero_local_rank, group_id, option="proxy_only"
+                )
+
                 send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
+                send_obj_meta(proxy_params, next_rank=self._memory_balance_peer)
                 dist.send(flat_proxy_param, dst=self._memory_balance_peer)
             elif _enable_memory_balance and self._memory_balance_role == 1:
-                flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
-                flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
+                flat_proxy_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
+                proxy_param_shapes = recv_obj_meta(prev_rank=self._memory_balance_peer)
+
+                # fix recv_obj_meta result when length of proxy_params is 1.
+                if isinstance(proxy_param_shapes, torch.Size):
+                    proxy_param_shapes = [proxy_param_shapes]
+
+                flat_proxy_param = torch.empty(flat_proxy_shape, device=get_current_device(), dtype=self._dtype)
                 dist.recv(flat_proxy_param, src=self._memory_balance_peer)
 
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
-                fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
-                )
-                fp32_flat_current_rank = fp16_flat_current_rank.float()
                 device = "cpu" if self._cpu_offload else get_current_device()
-                fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-                fp32_flat_current_rank.requires_grad = True
-                self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+
+                # parameters belong to current rank
+                fp16_params = self._param_store.get_fp16_params_by_rank_group(
+                    self._zero_local_rank, group_id, option="without_proxy"
+                )
+                fp32_params = _create_fp32_param_copy(fp16_params, device)
+                flat_fp32_param = _flatten_and_sync_params(fp32_params)
+                self._fp32_orig_param_groups_of_current_rank[group_id] = fp32_params
+                self._fp32_flat_param_groups_of_current_rank[group_id] = flat_fp32_param
 
-                # need to replace the params in the `params` field in the optimizer
-                # so that when the optimizer calls step(), it only updates the tensors
-                # managed by this data parallel rank
-                optim_params = [fp32_flat_current_rank]
+                # proxy parameters
+                fp32_proxy_params = []
                 if _enable_memory_balance and self._memory_balance_role == 1:
-                    flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
-                    flat_proxy_param.requires_grad = True
-                    optim_params.append(flat_proxy_param)
+                    # create empty tensors for fp32 proxy parameters
+                    for _shape in proxy_param_shapes:
+                        fp32_proxy_param = torch.empty(_shape, dtype=torch.float32, device=device)
+                        fp32_proxy_param.requires_grad = True
+                        fp32_proxy_params.append(fp32_proxy_param)
+
+                    # sync with received flat fp32 proxy parameter
+                    flat_proxy_param = flat_proxy_param.to(device=device, dtype=torch.float32)
+                    sync_param(flat_proxy_param, fp32_proxy_params)
+                    self._fp32_proxy_param_groups_of_current_rank = fp32_proxy_params
+
                     self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param
 
-                param_group["params"] = optim_params
+                param_group["params"] = fp32_params + fp32_proxy_params
 
             # set reduction state
             for param in self._fp16_param_groups[group_id]:
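
Note: the optimizer's `params` field now holds the per-parameter fp32 tensors rather than one flat tensor, while `_flatten_and_sync_params` / `sync_param` keep a flat view for communication. These helpers are assumed here to follow the usual flatten-then-rebind pattern, where the flat buffer and the per-parameter tensors share storage; a minimal sketch of that pattern (the helper below is illustrative, not the repo's exact implementation):

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def flatten_and_sync(params):
        # copy all params into one contiguous buffer
        flat = _flatten_dense_tensors([p.data for p in params])
        # rebind each param's storage to the matching slice of the buffer
        views = _unflatten_dense_tensors(flat, [p.data for p in params])
        for p, synced in zip(params, views):
            p.data = synced
        return flat

    params = [torch.randn(3, requires_grad=True) for _ in range(2)]
    flat = flatten_and_sync(params)
    flat.zero_()                        # writing through the flat buffer...
    assert params[0].abs().sum() == 0   # ...is visible through the param views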
@@ -543,7 +575,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_store.reset_reduced_data_for_compute_norm()
 
         # accumulate gradient
-        proxy_gradinets = []
         avg_gradients = self._grad_store._averaged_gradients
 
         for group_id in range(self.num_param_groups):
@@ -568,29 +599,13 @@ class HybridZeroOptimizer(BaseOptimizer):
                     param_idx += 1
 
         if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
+            self._proxy_param_gradients_of_current_rank = []
             param_group = self._param_store.get_fp16_params_by_rank_group(
                 self._zero_local_rank, group_id, option="proxy_only"
             )
             for param in param_group:
                 assert param.grad is not None, "gradient of proxy parameter is None"
-                proxy_gradinets.append(param.grad)
+                self._proxy_param_gradients_of_current_rank.append(param.grad)
 
-        # send offload gradients to reciever
-        if self._enable_memory_balance and self._memory_balance_role == 0:
-            flat_proxy_grads = flatten(proxy_gradinets)
-
-            dist.send(flat_proxy_grads, self._memory_balance_peer)
-            # torch.cuda.synchronize()
-        elif self._enable_memory_balance and self._enable_memory_balance == 1:
-            _shape = self._fp32_flat_proxy_param_of_current_rank.shape
-            _device = self._fp32_flat_proxy_param_of_current_rank.device
-            flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
-
-            dist.recv(flat_proxy_gradient, self._memory_balance_peer)
-            # torch.cuda.synchronize()
-            self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
-                dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
-            )
-
         # the gradients needed are stored in the avg_gradients buffer
         # thus, can clear this
@@ -745,32 +760,25 @@ class HybridZeroOptimizer(BaseOptimizer):
             return False, norms
 
         # copy the grad of fp16 param to fp32 param
-        single_grad_partition_groups = []
+        grads_partition_groups = []
         for group_id in range(self.num_param_groups):
-            # compute norm
-            # The following operations are performed only on the rank to which parameters are assigned.
             if not self.param_group_has_params[group_id]:
                 continue
 
-            # create flat gradient for the flat fp32 params
             gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
-            with torch.no_grad():
-                flat_fp16_avg_grads = flatten(gradients)
             self._grad_store.reset_average_gradients_by_group(group_id)
-            gradients = None  # release cuda memory
+            fp32_params = self._fp32_orig_param_groups_of_current_rank[group_id]
+            grads_partition_groups.append([])
 
-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-            flat_fp16_avg_grads = None  # release cuda memory
+            for idx, grad in enumerate(gradients):
+                fp32_grad = grad.data.float()
+                fp32_params[idx].grad = fp32_grad
+                grads_partition_groups[group_id].append(fp32_grad)
 
-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert (
-                param_shape == flat_fp32_avg_grads.shape
-            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+            if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
+                grads_partition_groups[group_id].extend(self._proxy_param_gradients_of_current_rank)
 
-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            gradients = None
 
         # unscale and clip grads
         # get the global norm
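
Note: this hunk is the core of the peak-memory reduction. The old path built a temporary flat fp16 copy of the whole gradient partition (`flatten(gradients)`) and then cast it to a flat fp32 copy, so two partition-sized temporaries coexisted before anything was freed. The new path casts each gradient directly to fp32 and attaches it to the matching master parameter, skipping the flat fp16 intermediate entirely. A schematic comparison (shapes are hypothetical):

    import torch

    grads = [torch.randn(1024, dtype=torch.float16) for _ in range(4)]

    def old_path(grads):
        # whole-partition flat fp16 copy and flat fp32 copy coexist briefly
        flat_fp16 = torch.cat([g.view(-1) for g in grads])  # + full partition
        flat_fp32 = flat_fp16.float()                        # + full partition again
        return flat_fp32

    def new_path(grads):
        # per-parameter fp32 copies only; no flat fp16 intermediate
        return [g.float() for g in grads]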
@@ -781,39 +789,50 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
-            if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
+            if len(grads_partition_groups) != 0 and self._clip_grad_norm > 0:
                 self._unscale_and_clip_grads(
-                    single_grad_partition_groups,
+                    grads_partition_groups,
                     list(global_norm_groups.values()),
                     loss_scale,
                 )
+        grads_partition_groups = None
 
         # update the parameters
         timer("step").start()
 
+        # send and receive proxy gradients
+        if self._enable_memory_balance and self._memory_balance_role == 0:
+            for gradient in self._proxy_param_gradients_of_current_rank:
+                dist.send(gradient, dst=self._memory_balance_peer)
+            self._proxy_param_gradients_of_current_rank = None
+        elif self._enable_memory_balance and self._memory_balance_role == 1:
+            for proxy_param in self._fp32_proxy_param_groups_of_current_rank:
+                proxy_gradient = torch.empty(proxy_param.shape, device=proxy_param.device, dtype=self._dtype)
+                dist.recv(proxy_gradient, self._memory_balance_peer)
+                proxy_param.grad = proxy_gradient.to(dtype=proxy_param.dtype)
+
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
             self.optim.step()
             # release the fp32 grad
-            release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            if self._enable_memory_balance and self._memory_balance_role == 1:
-                self._fp32_flat_proxy_param_of_current_rank.grad = None
+            for group_id in range(self.num_param_groups):
+                release_param_grad(self._fp32_orig_param_groups_of_current_rank[group_id])
+
+            if self._enable_memory_balance and self._memory_balance_role == 1:
+                release_param_grad(self._fp32_proxy_param_groups_of_current_rank)
 
         # receive proxy params
         if self._enable_memory_balance and self._memory_balance_role == 0:
             flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
                 rank=self._zero_local_rank, group_id=0
             )
-            dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
-            # torch.cuda.synchronize()
+            dist.recv(flat_proxy_param, self._memory_balance_peer)
         elif self._enable_memory_balance and self._memory_balance_role == 1:
             flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
-            dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
-            # torch.cuda.synchronize()
+            dist.send(flat_proxy_param, self._memory_balance_peer)
 
         # update fp16 partition updated by the current rank
-        for group_id in range(len(self._fp16_param_groups)):
+        for group_id in range(self.num_param_groups):
             if self.param_group_has_params[group_id]:
                 fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
                     rank=self._zero_local_rank, group_id=group_id
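
Note: the proxy-gradient exchange moves from reduce time to just before `optim.step()`, and the flat per-group buffer is replaced by per-parameter sends. Blocking `dist.send`/`dist.recv` pairs require both ranks to agree on the tensor shapes, dtype, and ordering; a minimal sketch of that contract, assuming an already-initialized process group (the function name and `comm_dtype` default are illustrative):

    import torch
    import torch.distributed as dist

    def exchange_proxy_gradients(role, peer, tensors, comm_dtype=torch.float16):
        # role 0 sends its proxy gradients; role 1 receives them in the same order.
        # send/recv are blocking, so shapes, dtype, and ordering must match exactly.
        if role == 0:
            for grad in tensors:                  # model-dtype gradients on the sender
                dist.send(grad, dst=peer)
        else:
            for param in tensors:                 # fp32 proxy params on the receiver
                buf = torch.empty(param.shape, device=param.device, dtype=comm_dtype)
                dist.recv(buf, src=peer)
                param.grad = buf.to(dtype=param.dtype)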
@@ -894,7 +913,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return self._found_overflow.item() > 0
 
-    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
+    def _unscale_and_clip_grads(self, gradients_groups, total_norm_groups, loss_scale):
         # compute combined scale factor for this group
         combined_scale_groups = []
 
@@ -906,7 +925,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             if clip > 1.0:
                 combined_scale_groups[group_id] = clip * loss_scale
 
-        for group_id, grad in enumerate(grad_groups_flat):
-            grad.data.mul_(1.0 / combined_scale_groups[group_id])
+        for group_id, grads in enumerate(gradients_groups):
+            for grad in grads:
+                grad.data.mul_(1.0 / combined_scale_groups[group_id])
 
     def clip_grad_norm(self, model, max_norm):
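
Note: `_unscale_and_clip_grads` now takes a list of gradient lists (one list per param group) instead of one flat tensor per group, so the inner loop applies the group's combined scale to every gradient. A worked example of the scaling arithmetic, assuming the clip ratio is computed as `(total_norm / loss_scale) / max_norm` in the elided lines (the numbers are hypothetical):

    loss_scale = 1024.0
    clip_grad_norm = 1.0
    total_norm = 2048.0                  # pre-unscale norm of one group

    clip = (total_norm / loss_scale) / clip_grad_norm                 # 2.0
    combined_scale = clip * loss_scale if clip > 1.0 else loss_scale  # 2048.0

    grad = 4096.0
    print(grad / combined_scale)         # 2.0 -> unscaled and clipped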