diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index ee42b4a..2f53963 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -65,6 +65,18 @@ def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
     return flat_tensor
 
 
+def _create_fp32_param_copy(parameters: List[torch.Tensor], device: torch.device) -> List[torch.Tensor]:
+    fp32_params = []
+
+    # create an fp32 copy of each parameter for the optimizer to update
+    for param in parameters:
+        fp32_param = param.data.to(device, dtype=torch.float32)
+        fp32_param.requires_grad = True
+        fp32_params.append(fp32_param)
+
+    return fp32_params
+
+
 class BaseOptimizer(Optimizer):
     """
     Base Optimizer.
@@ -140,6 +152,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         super().__init__(optim=optimizer)
 
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
+        self._dtype_memory = self.optim.param_groups[0]["params"][0].element_size()  # bytes per element
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -154,6 +167,7 @@
 
         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
+        self._fp32_orig_param_groups_of_current_rank = dict()
        self._fp32_flat_param_groups_of_current_rank = dict()
 
         # communication params
@@ -175,6 +189,9 @@
             )  # 0: sender, 1: receiver
             self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
             self._fp32_flat_proxy_param_of_current_rank = None
+            self._fp32_proxy_param_groups_of_current_rank = None
+
+            self._proxy_param_gradients_of_current_rank = None
 
             compensation_conf = {
                 k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
@@ -196,10 +213,11 @@
             # divided by 2.
             self._memory_balance_amount = (
                 (zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
-                / 2
-                / 3
-                / 2
+                / 2  # total imbalance -> the half that needs to move
+                / 3  # optim param + exp_avg + exp_avg_sq -> the optim param part only
             )
+            # convert from fp32 optimizer-parameter bytes to model-parameter-dtype bytes
+            self._memory_balance_amount /= 4 / self._dtype_memory
 
         # gradient scaler
         self.grad_scaler = DynamicGradScaler(
@@ -294,39 +312,53 @@
 
             if _enable_memory_balance and self._memory_balance_role == 0:
                 flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)
+                proxy_params = self._param_store.get_fp16_params_by_rank_group(
+                    self._zero_local_rank, group_id, option="proxy_only"
+                )
                 send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
+                send_obj_meta(proxy_params, next_rank=self._memory_balance_peer)
                 dist.send(flat_proxy_param, dst=self._memory_balance_peer)
             elif _enable_memory_balance and self._memory_balance_role == 1:
-                flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
+                flat_proxy_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)
+                proxy_param_shapes = recv_obj_meta(prev_rank=self._memory_balance_peer)
 
-                flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
+                # recv_obj_meta returns a bare torch.Size rather than a list when proxy_params has length 1
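+                # (torch.Size is a tuple subclass, so iterating it directly below would yield ints)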
+                if isinstance(proxy_param_shapes, torch.Size):
+                    proxy_param_shapes = [proxy_param_shapes]
+
+                flat_proxy_param = torch.empty(flat_proxy_shape, device=get_current_device(), dtype=self._dtype)
                 dist.recv(flat_proxy_param, src=self._memory_balance_peer)
 
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
-                fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
-                )
-                fp32_flat_current_rank = fp16_flat_current_rank.float()
                 device = "cpu" if self._cpu_offload else get_current_device()
-                fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-                fp32_flat_current_rank.requires_grad = True
-                self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
-
-                # need to replace the params in the `params` field in the optimizer
-                # so that when the optimizer calls step(), it only updates the tensors
-                # managed by this data parallel rank
-                optim_params = [fp32_flat_current_rank]
+                # parameters belonging to the current rank
+                fp16_params = self._param_store.get_fp16_params_by_rank_group(
+                    self._zero_local_rank, group_id, option="without_proxy"
+                )
+                fp32_params = _create_fp32_param_copy(fp16_params, device)
+                flat_fp32_param = _flatten_and_sync_params(fp32_params)
+                self._fp32_orig_param_groups_of_current_rank[group_id] = fp32_params
+                self._fp32_flat_param_groups_of_current_rank[group_id] = flat_fp32_param
 
+                # proxy parameters
+                fp32_proxy_params = []
                 if _enable_memory_balance and self._memory_balance_role == 1:
-                    flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
-                    flat_proxy_param.requires_grad = True
-                    optim_params.append(flat_proxy_param)
+                    # create empty tensors for the fp32 proxy parameters
+                    for _shape in proxy_param_shapes:
+                        fp32_proxy_param = torch.empty(_shape, dtype=torch.float32, device=device)
+                        fp32_proxy_param.requires_grad = True
+                        fp32_proxy_params.append(fp32_proxy_param)
+                    # sync with the received flat proxy parameter
+                    flat_proxy_param = flat_proxy_param.to(device=device, dtype=torch.float32)
+                    sync_param(flat_proxy_param, fp32_proxy_params)
+                    self._fp32_proxy_param_groups_of_current_rank = fp32_proxy_params
                     self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param
 
-                param_group["params"] = optim_params
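+                # the optimizer now steps over individual fp32 tensors; proxy params are
+                # appended so this rank also updates them on the sender's behalf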
+                param_group["params"] = fp32_params + fp32_proxy_params
 
             # set reduction state
             for param in self._fp16_param_groups[group_id]:
@@ -543,7 +575,6 @@
         self._param_store.reset_reduced_data_for_compute_norm()
 
         # accumulate gradient
-        proxy_gradinets = []
         avg_gradients = self._grad_store._averaged_gradients
         for group_id in range(self.num_param_groups):
@@ -568,29 +599,13 @@
                     param_idx += 1
 
             if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
+                self._proxy_param_gradients_of_current_rank = []
                 param_group = self._param_store.get_fp16_params_by_rank_group(
                     self._zero_local_rank, group_id, option="proxy_only"
                 )
                 for param in param_group:
                     assert param.grad is not None, "gradient of proxy parameter is None"
-                    proxy_gradinets.append(param.grad)
-
-        # send offload gradients to reciever
-        if self._enable_memory_balance and self._memory_balance_role == 0:
-            flat_proxy_grads = flatten(proxy_gradinets)
-
-            dist.send(flat_proxy_grads, self._memory_balance_peer)
-            # torch.cuda.synchronize()
-        elif self._enable_memory_balance and self._enable_memory_balance == 1:
-            _shape = self._fp32_flat_proxy_param_of_current_rank.shape
-            _device = self._fp32_flat_proxy_param_of_current_rank.device
-            flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
-
-            dist.recv(flat_proxy_gradient, self._memory_balance_peer)
-            # torch.cuda.synchronize()
-            self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
-                dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
-            )
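+                    # keep the proxy gradients; they are now sent from step(),
+                    # replacing the eager flatten-and-send removed above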
+                    self._proxy_param_gradients_of_current_rank.append(param.grad)
 
         # the gradients needed are stored in the avg_gradients buffer
         # thus, can clear this
@@ -745,32 +760,25 @@
             return False, norms
 
         # copy the grad of fp16 param to fp32 param
-        single_grad_partition_groups = []
+        grads_partition_groups = []
         for group_id in range(self.num_param_groups):
-            # compute norm
-            # The following operations are performed only on the rank to which parameters are assigned.
             if not self.param_group_has_params[group_id]:
                 continue
 
-            # create flat gradient for the flat fp32 params
             gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
-            with torch.no_grad():
-                flat_fp16_avg_grads = flatten(gradients)
             self._grad_store.reset_average_gradients_by_group(group_id)
-            gradients = None  # release cuda memory
+            fp32_params = self._fp32_orig_param_groups_of_current_rank[group_id]
+            grads_partition_groups.append([])
 
-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-            flat_fp16_avg_grads = None  # release cuda memory
+            # attach an fp32 copy of each gradient directly to its fp32 param
+            for idx, grad in enumerate(gradients):
+                fp32_grad = grad.data.float()
+                fp32_params[idx].grad = fp32_grad
+                grads_partition_groups[-1].append(fp32_grad)
 
-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert (
-                param_shape == flat_fp32_avg_grads.shape
-            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+            if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
+                grads_partition_groups[-1].extend(self._proxy_param_gradients_of_current_rank)
 
-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            gradients = None  # release cuda memory
 
         # unscale and clip grads
         # get the global norm
@@ -781,39 +789,50 @@
         # the following operations are performed only on the rank to which parameters are assigned.
         if gpc.config.model.dtype is not torch.float32:
-            if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
+            if len(grads_partition_groups) != 0 and self._clip_grad_norm > 0:
                 self._unscale_and_clip_grads(
-                    single_grad_partition_groups,
+                    grads_partition_groups,
                     list(global_norm_groups.values()),
                     loss_scale,
                 )
+                grads_partition_groups = None
 
         # update the parameters
         timer("step").start()
 
+        # send and receive proxy gradients
+        if self._enable_memory_balance and self._memory_balance_role == 0:
+            for gradient in self._proxy_param_gradients_of_current_rank:
+                dist.send(gradient, dst=self._memory_balance_peer)
+            self._proxy_param_gradients_of_current_rank = None
+        elif self._enable_memory_balance and self._memory_balance_role == 1:
+            for proxy_param in self._fp32_proxy_param_groups_of_current_rank:
+                proxy_gradient = torch.empty(proxy_param.shape, device=proxy_param.device, dtype=self._dtype)
+                dist.recv(proxy_gradient, self._memory_balance_peer)
+                proxy_param.grad = proxy_gradient.to(dtype=proxy_param.dtype)
+
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
             self.optim.step()
             # release the fp32 grad
-            release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            if self._enable_memory_balance and self._memory_balance_role == 1:
-                self._fp32_flat_proxy_param_of_current_rank.grad = None
+            for fp32_params in self._fp32_orig_param_groups_of_current_rank.values():
+                release_param_grad(fp32_params)
+
+            if self._enable_memory_balance and self._memory_balance_role == 1:
+                release_param_grad(self._fp32_proxy_param_groups_of_current_rank)
 
-        # receive proxy params
         if self._enable_memory_balance and self._memory_balance_role == 0:
             flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
                 rank=self._zero_local_rank, group_id=0
             )
-            dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
-            # torch.cuda.synchronize()
+            dist.recv(flat_proxy_param, self._memory_balance_peer)
         elif self._enable_memory_balance and self._memory_balance_role == 1:
             flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
-            dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
-            # torch.cuda.synchronize()
+            dist.send(flat_proxy_param, self._memory_balance_peer)
 
         # update fp16 partition updated by the current rank
-        for group_id in range(len(self._fp16_param_groups)):
+        for group_id in range(self.num_param_groups):
             if self.param_group_has_params[group_id]:
                 fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
                     rank=self._zero_local_rank, group_id=group_id
@@ -894,7 +913,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return self._found_overflow.item() > 0
 
-    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
+    def _unscale_and_clip_grads(self, gradients_groups, total_norm_groups, loss_scale):
         # compute combined scale factor for this group
         combined_scale_groups = []
 
@@ -906,8 +925,9 @@
             if clip > 1.0:
                 combined_scale_groups[group_id] = clip * loss_scale
 
-        for group_id, grad in enumerate(grad_groups_flat):
-            grad.data.mul_(1.0 / combined_scale_groups[group_id])
+        for group_id, grads in enumerate(gradients_groups):
+            for grad in grads:
+                grad.data.mul_(1.0 / combined_scale_groups[group_id])
 
     def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
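
For reference, the per-parameter mixed-precision flow this diff introduces can be sketched standalone. Everything below is illustrative only; the shapes, CPU device, SGD choice, and copy-back loop are assumptions for the example, not code from this PR:

import torch

# two fake fp16 "model" parameters standing in for one rank's shard
fp16_params = [torch.randn(4, 4, dtype=torch.float16) for _ in range(2)]

# fp32 master copies, mirroring _create_fp32_param_copy
fp32_params = []
for param in fp16_params:
    fp32_param = param.data.to("cpu", dtype=torch.float32)
    fp32_param.requires_grad = True
    fp32_params.append(fp32_param)

# the optimizer now steps over individual fp32 tensors instead of one flat buffer
optimizer = torch.optim.SGD(fp32_params, lr=0.1)

# attach per-parameter fp32 gradients, as the reworked grad-copy loop does
for param, master in zip(fp16_params, fp32_params):
    fp16_grad = torch.randn_like(param)  # stands in for the averaged gradient
    master.grad = fp16_grad.data.float()

optimizer.step()

# copy the updated fp32 masters back into the fp16 parameters
for param, master in zip(fp16_params, fp32_params):
    param.data.copy_(master.data)  # copy_ casts fp32 -> fp16

This also shows why _unscale_and_clip_grads now takes lists of per-parameter gradients per group: there is no single flat fp32 gradient left to scale.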