feat(hybrid_zero_optim): reduce optimizer peak memory

pull/306/head
mwiacx 2023-09-21 14:22:08 +08:00
parent 10f01c4e08
commit fda0a96def
1 changed file with 91 additions and 71 deletions


@@ -65,6 +65,18 @@ def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
return flat_tensor
def _create_fp32_param_copy(parameters: List[torch.Tensor], device: torch.device) -> List[torch.Tensor]:
fp32_params = []
# create fp32 parameter copy
for param in parameters:
fp32_param = param.data.to(device, dtype=torch.float32)
fp32_param.requires_grad = True
fp32_params.append(fp32_param)
return fp32_params
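
For context on the two helpers: _create_fp32_param_copy builds standalone fp32 master copies, and _flatten_and_sync_params (per its use below) flattens them into one contiguous buffer that the per-parameter tensors then alias. A minimal sketch of that flatten-and-alias pattern, using torch's private flatten utilities; illustrative, not the patch's exact body:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flatten_and_sync(tensors):
    # one contiguous buffer holding every tensor back to back
    flat = _flatten_dense_tensors(tensors)
    # repoint each tensor's storage at its slice of the flat buffer,
    # so an update through either view is visible through both
    for t, view in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
        t.data = view
    return flat

With this layout the optimizer steps on the small per-parameter tensors while the flat buffer stays available for bulk copies and communication, with no second full-size allocation.
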
class BaseOptimizer(Optimizer):
"""
Base Optimizer.
@@ -140,6 +152,7 @@ class HybridZeroOptimizer(BaseOptimizer):
super().__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._dtype_memory = self.optim.param_groups[0]["params"][0].element_size()
self._cpu_offload = cpu_offload
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -154,6 +167,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_orig_param_groups_of_current_rank = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
# communication params
@@ -175,6 +189,9 @@ class HybridZeroOptimizer(BaseOptimizer):
) # 0: sender, 1: receiver
self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
self._fp32_flat_proxy_param_of_current_rank = None
self._fp32_proxy_param_groups_of_current_rank = None
self._proxy_param_gradients_of_current_rank = None
compensation_conf = {
k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
@@ -196,10 +213,11 @@ class HybridZeroOptimizer(BaseOptimizer):
# divided by 2.
self._memory_balance_amount = (
(zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
/ 2
/ 3
/ 2
/ 2 # total -> need to move
/ 3 # optim param, exp_avg, exp_avg_sq -> optim param
)
# convert optimizer parameter dtype to model parameter dtype.
self._memory_balance_amount /= 4 / self._dtype_memory
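
A worked example of the division chain, with assumed figures (a peer one pipeline rank away, zero compensation, fp16/bf16 model params): half of the imbalance has to move, Adam holds three fp32 states per element (param, exp_avg, exp_avg_sq) of which only the param share is proxied, and the result is converted from fp32 optimizer bytes into model-dtype bytes:

imbalance = 6 * 1024**3       # hypothetical: 6 GiB imbalance vs. the peer
dtype_memory = 2              # element size of fp16/bf16 model params

amount = imbalance / 2        # total -> the half that needs to move
amount /= 3                   # param, exp_avg, exp_avg_sq -> param share only
amount /= 4 / dtype_memory    # fp32 optimizer bytes -> model-dtype bytes
print(amount / 1024**2)       # 512.0 (MiB of fp16 parameters to proxy)
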
# gradient scaler
self.grad_scaler = DynamicGradScaler(
@@ -294,39 +312,53 @@ class HybridZeroOptimizer(BaseOptimizer):
if _enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)
proxy_params = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="proxy_only"
)
send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
send_obj_meta(proxy_params, next_rank=self._memory_balance_peer)
dist.send(flat_proxy_param, dst=self._memory_balance_peer)
elif _enable_memory_balance and self._memory_balance_role == 1:
flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
flat_proxy_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
proxy_param_shapes = recv_obj_meta(prev_rank=self._memory_balance_peer)
flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
# recv_obj_meta returns a bare torch.Size when proxy_params has length 1; normalize it to a list.
if isinstance(proxy_param_shapes, torch.Size):
proxy_param_shapes = [proxy_param_shapes]
flat_proxy_param = torch.empty(flat_proxy_shape, device=get_current_device(), dtype=self._dtype)
dist.recv(flat_proxy_param, src=self._memory_balance_peer)
# create an fp32 copy of the parameters for which this rank is responsible
# No flat fp32 buffer is allocated if the process has no parameters.
if self.param_group_has_params[group_id]:
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
self._zero_local_rank, group_id
)
fp32_flat_current_rank = fp16_flat_current_rank.float()
device = "cpu" if self._cpu_offload else get_current_device()
fp32_flat_current_rank = fp32_flat_current_rank.to(device)
fp32_flat_current_rank.requires_grad = True
self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
optim_params = [fp32_flat_current_rank]
# parameters belonging to the current rank
fp16_params = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="without_proxy"
)
fp32_params = _create_fp32_param_copy(fp16_params, device)
flat_fp32_param = _flatten_and_sync_params(fp32_params)
self._fp32_orig_param_groups_of_current_rank[group_id] = fp32_params
self._fp32_flat_param_groups_of_current_rank[group_id] = flat_fp32_param
# proxy parameters
fp32_proxy_params = []
if _enable_memory_balance and self._memory_balance_role == 1:
flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
flat_proxy_param.requires_grad = True
optim_params.append(flat_proxy_param)
# create empty tensors for the fp32 proxy parameters
for _shape in proxy_param_shapes:
fp32_proxy_param = torch.empty(_shape, dtype=torch.float32, device=device)
fp32_proxy_param.requires_grad = True
fp32_proxy_params.append(fp32_proxy_param)
# sync with received flat fp32 proxy parameter
flat_proxy_param = flat_proxy_param.to(device=device, dtype=torch.float32)
sync_param(flat_proxy_param, fp32_proxy_params)
self._fp32_proxy_param_groups_of_current_rank = fp32_proxy_params
self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param
param_group["params"] = optim_params
param_group["params"] = fp32_params + fp32_proxy_params
# set reduction state
for param in self._fp16_param_groups[group_id]:
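
On the receiver side, the handshake above amounts to: accept the flat fp16 proxy tensor via dist.recv, promote it to fp32, and carve per-parameter tensors out of it so they can join the optimizer's param group. A minimal sketch under assumed names; the patch routes shape metadata through send_obj_meta/recv_obj_meta (elided here), and whether the per-parameter tensors alias or copy the flat buffer depends on sync_param (the sketch aliases):

import torch
import torch.distributed as dist

def receive_proxy_params(flat_shape, param_shapes, peer, device, wire_dtype):
    flat = torch.empty(flat_shape, device=device, dtype=wire_dtype)
    dist.recv(flat, src=peer)            # flat fp16/bf16 payload from the sender
    flat = flat.to(dtype=torch.float32)  # fp32 master copy
    params, offset = [], 0
    for shape in param_shapes:
        n = int(torch.Size(shape).numel())
        p = flat[offset:offset + n].view(shape)  # alias into the flat buffer
        p.requires_grad_(True)           # legal: flat itself tracks no autograd
        params.append(p)
        offset += n
    return flat, params
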
@@ -543,7 +575,6 @@ class HybridZeroOptimizer(BaseOptimizer):
self._param_store.reset_reduced_data_for_compute_norm()
# accumulate gradient
proxy_gradients = []
avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
@@ -568,29 +599,13 @@ class HybridZeroOptimizer(BaseOptimizer):
param_idx += 1
if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
self._proxy_param_gradients_of_current_rank = []
param_group = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="proxy_only"
)
for param in param_group:
assert param.grad is not None, "gradient of proxy parameter is None"
proxy_gradients.append(param.grad)
# send offload gradients to receiver
if self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_grads = flatten(proxy_gradients)
dist.send(flat_proxy_grads, self._memory_balance_peer)
# torch.cuda.synchronize()
elif self._enable_memory_balance and self._memory_balance_role == 1:
_shape = self._fp32_flat_proxy_param_of_current_rank.shape
_device = self._fp32_flat_proxy_param_of_current_rank.device
flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
dist.recv(flat_proxy_gradient, self._memory_balance_peer)
# torch.cuda.synchronize()
self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
)
self._proxy_param_gradients_of_current_rank.append(param.grad)
# the gradients needed are stored in the avg_gradients buffer
# thus, it can be cleared here
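
Note the behavioral change here: the sender no longer flattens and ships the proxy gradients inside gradient accumulation; it only stashes references, and the actual exchange happens later in step(). A sketch of that deferred, per-tensor exchange (names illustrative):

import torch
import torch.distributed as dist

def send_proxy_grads(proxy_grads, peer):
    for g in proxy_grads:
        dist.send(g, dst=peer)           # no flatten() -> no transient flat copy

def recv_proxy_grads(fp32_proxy_params, peer, wire_dtype):
    for p in fp32_proxy_params:
        buf = torch.empty(p.shape, device=p.device, dtype=wire_dtype)
        dist.recv(buf, src=peer)
        p.grad = buf.to(dtype=p.dtype)   # attach as the fp32 gradient
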
@@ -745,32 +760,25 @@ class HybridZeroOptimizer(BaseOptimizer):
return False, norms
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []
grads_partition_groups = []
for group_id in range(self.num_param_groups):
# compute norm
# The following operations are performed only on the rank to which parameters are assigned.
if not self.param_group_has_params[group_id]:
continue
# create flat gradient for the flat fp32 params
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
with torch.no_grad():
flat_fp16_avg_grads = flatten(gradients)
self._grad_store.reset_average_gradients_by_group(group_id)
gradients = None # release cuda memory
fp32_params = self._fp32_orig_param_groups_of_current_rank[group_id]
grads_partition_groups.append([])
dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
flat_fp16_avg_grads = None # release cuda memory
for idx, grad in enumerate(gradients):
fp32_grad = grad.data.float()
fp32_params[idx].grad = fp32_grad
grads_partition_groups[-1].append(fp32_grad)
param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
assert (
param_shape == flat_fp32_avg_grads.shape
), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
grads_partition_groups[-1].extend(self._proxy_param_gradients_of_current_rank)
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
gradients = None
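
This loop carries most of the peak reduction: the old path materialized a flat fp16 copy of all averaged gradients and then a second full-size fp32 copy before either could be freed, while the new path converts one gradient at a time, so the transient overhead is a single parameter's worth. Schematically (illustrative):

import torch

def grads_to_fp32(fp32_params, fp16_grads):
    group = []
    for p, g in zip(fp32_params, fp16_grads):
        g32 = g.data.float()   # only this one extra tensor exists at a time
        p.grad = g32           # consumed directly by optimizer.step()
        group.append(g32)
    return group
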
# unscale and clip grads
# get the global norm
@@ -781,39 +789,50 @@ class HybridZeroOptimizer(BaseOptimizer):
# the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
if len(grads_partition_groups) != 0 and self._clip_grad_norm > 0:
self._unscale_and_clip_grads(
single_grad_partition_groups,
grads_partition_groups,
list(global_norm_groups.values()),
loss_scale,
)
grads_partition_groups = None
# update the parameters
timer("step").start()
# send and receive proxy gradients
if self._enable_memory_balance and self._memory_balance_role == 0:
for gradient in self._proxy_param_gradients_of_current_rank:
dist.send(gradient, dst=self._memory_balance_peer)
self._proxy_param_gradients_of_current_rank = None
elif self._enable_memory_balance and self._memory_balance_role == 1:
for proxy_param in self._fp32_proxy_param_groups_of_current_rank:
proxy_gradient = torch.empty(proxy_param.shape, device=proxy_param.device, dtype=self._dtype)
dist.recv(proxy_gradient, self._memory_balance_peer)
proxy_param.grad = proxy_gradient.to(dtype=proxy_param.dtype)
# For those ranks that are not assigned parameters, we just wait for other ranks
# to send them the updated parameters.
if self.has_params:
self.optim.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
if self._enable_memory_balance and self._memory_balance_role == 1:
self._fp32_flat_proxy_param_of_current_rank.grad = None
for group_id in range(self.num_param_groups):
release_param_grad(self._fp32_orig_param_groups_of_current_rank[group_id])
if self._enable_memory_balance and self._memory_balance_role == 1:
release_param_grad(self._fp32_proxy_param_groups_of_current_rank)
# receive proxy params
if self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
rank=self._zero_local_rank, group_id=0
)
dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
# torch.cuda.synchronize()
dist.recv(flat_proxy_param, self._memory_balance_peer)
elif self._enable_memory_balance and self._memory_balance_role == 1:
flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
# torch.cuda.synchronize()
dist.send(flat_proxy_param, self._memory_balance_peer)
# update the fp16 partition owned by the current rank
for group_id in range(len(self._fp16_param_groups)):
for group_id in range(self.num_param_groups):
if self.param_group_has_params[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
rank=self._zero_local_rank, group_id=group_id
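
After step(), the updated fp32 master values flow back to fp16: because the per-parameter fp32 tensors alias one flat buffer (see the flatten/sync sketch above), refreshing the fp16 partition is a single in-place casting copy per group. A minimal sketch:

import torch

def refresh_fp16_partition(flat_fp16: torch.Tensor, flat_fp32: torch.Tensor) -> None:
    with torch.no_grad():
        flat_fp16.copy_(flat_fp32)  # copy_ casts fp32 -> fp16 in place
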
@@ -894,7 +913,7 @@ class HybridZeroOptimizer(BaseOptimizer):
return self._found_overflow.item() > 0
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
def _unscale_and_clip_grads(self, gradients_groups, total_norm_groups, loss_scale):
# compute combined scale factor for this group
combined_scale_groups = []
@@ -906,7 +925,8 @@ class HybridZeroOptimizer(BaseOptimizer):
if clip > 1.0:
combined_scale_groups[group_id] = clip * loss_scale
for group_id, grad in enumerate(grad_groups_flat):
for group_id, grads in enumerate(gradients_groups):
for grad in grads:
grad.data.mul_(1.0 / combined_scale_groups[group_id])
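
The reworked _unscale_and_clip_grads walks a list of per-group gradient lists instead of one flat tensor per group, but the math is unchanged: a single fused multiply both unscales and clips. A worked example with assumed numbers:

loss_scale = 65536.0
total_norm = 131072.0   # norm of the still-scaled gradients
max_norm = 1.0          # clipping threshold

clip = (total_norm / loss_scale) / max_norm        # true norm / max_norm = 2.0
combined_scale = clip * loss_scale if clip > 1.0 else loss_scale
factor = 1.0 / combined_scale   # each grad.data.mul_(factor) unscales and clips
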
def clip_grad_norm(self, model, max_norm):