diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 94c3e05..f3b2fe1 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -87,6 +87,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        use_fp16: bool = True,
     ):
         # DynamicGradScaler related args
         initial_scale = grad_scal_cfg.fp16.initial_scale
@@ -104,6 +105,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         super().__init__(optim=optimizer)
 
+        self.use_fp16 = use_fp16
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -125,14 +127,18 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._reduce_bucket_size = reduce_bucket_size
 
         # gradient scaler
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
+        self.grad_scaler = (
+            DynamicGradScaler(
+                initial_scale=initial_scale,
+                min_scale=min_scale,
+                growth_factor=growth_factor,
+                backoff_factor=backoff_factor,
+                growth_interval=growth_interval,
+                hysteresis=hysteresis,
+                max_scale=max_scale,
+            )
+            if self.use_fp16
+            else None
         )
 
         self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
@@ -176,11 +182,14 @@ class HybridZeroOptimizer(BaseOptimizer):
                     for param in params:
                         self._param_store.set_param_to_rank(param, rank)
 
+            # flatten the reordered tensors
             # move to cpu to make room to create the flat tensor
+            # Even for fp32 training, we still flatten the tensors,
+            # which does not increase GPU memory usage
+            # and improves the efficiency of broadcasting.
             for param in group_params:
                 param.data = param.data.cpu()
 
-            # flatten the reordered tensors
            for rank in range(self._zero_world_size):
                 # No flat fp16 buffer is allocated if the process has no parameters.
                 if rank not in self.param_group_no_params_ranks[group_id]:
@@ -194,19 +203,25 @@ class HybridZeroOptimizer(BaseOptimizer):
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
-                fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
-                )
-                fp32_flat_current_rank = fp16_flat_current_rank.float()
-                device = "cpu" if self._cpu_offload else get_current_device()
-                fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-                fp32_flat_current_rank.requires_grad = True
-                self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+                if self.use_fp16:
+                    fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )
+                    fp32_flat_current_rank = fp16_flat_current_rank.float()
+                    device = "cpu" if self._cpu_offload else get_current_device()
+                    fp32_flat_current_rank = fp32_flat_current_rank.to(device)
+                    fp32_flat_current_rank.requires_grad = True
+                    self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
 
-                # need to replace the params in the `params` field in the optimizer
-                # so that when the optimizer calls step(), it only updates the tensors
-                # managed by this data parallel rank
-                param_group["params"] = [fp32_flat_current_rank]
+                    # need to replace the params in the `params` field in the optimizer
+                    # so that when the optimizer calls step(), it only updates the tensors
+                    # managed by this data parallel rank
+                    param_group["params"] = [fp32_flat_current_rank]
+                else:
+                    # fp32: the optimizer updates the original per-rank parameters directly
+                    param_group["params"] = self._param_store.get_fp16_params_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )
 
             # set reduction state
             for param in self._fp16_param_groups[group_id]:
@@ -243,7 +258,10 @@ class HybridZeroOptimizer(BaseOptimizer):
 
     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        if self.grad_scaler is None:
+            return torch.cuda.FloatTensor([1], device=get_current_device())  # keep loss_scale.item() usable in fp32 mode
+        else:
+            return self.grad_scaler.scale
 
     @property
     def num_param_groups(self):
@@ -533,7 +551,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             norm_groups.append(norm_group)
 
         loss_scale = float(self.loss_scale.item())  # backup
-        self.grad_scaler.update(found_inf)
+        if self.grad_scaler:
+            self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
@@ -552,21 +571,30 @@ class HybridZeroOptimizer(BaseOptimizer):
                 continue
             gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
 
-            # create flat gradient for the flat fp32 params
-            fp16_avg_grads = gradients
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
+            if self.use_fp16:
+                # create flat gradient for the flat fp32 params
+                fp16_avg_grads = gradients
+                flat_fp16_avg_grads = flatten(fp16_avg_grads)
 
-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+                dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
+                flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
 
-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert (
-                param_shape == flat_fp32_avg_grads.shape
-            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+                param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
+                assert (
+                    param_shape == flat_fp32_avg_grads.shape
+                ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+
+                single_grad_partition_groups.append(flat_fp32_avg_grads)
+                device = self._fp32_flat_param_groups_of_current_rank[group_id].device
+                self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            else:
+                assert len(gradients) == len(self.optim.param_groups[group_id]["params"]), (
+                    len(gradients),
+                    len(self.optim.param_groups[group_id]["params"]),
+                )
+                for g, p in zip(gradients, self.optim.param_groups[group_id]["params"]):
+                    p.grad = g
 
-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
 
             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []
@@ -576,8 +604,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         global_norm = sum(norm_groups) ** 0.5
 
         # the following operations are performed only on the rank to which parameters are assigned.
-        if len(single_grad_partition_groups) != 0:
-            self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+        if self.use_fp16:
+            if len(single_grad_partition_groups) != 0:
+                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
 
         timer("cal_norm").stop()
         # update the parameters
@@ -588,15 +617,16 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self.has_params:
             self.optim.step()
             # release the fp32 grad
-            release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            # update fp16 partition updated by the current rank
-            for group_id in range(len(self._fp16_param_groups)):
-                if self.param_group_has_params[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                        rank=self._zero_local_rank, group_id=group_id
-                    )
-                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                    fp16_param.data.copy_(fp32_param)
+            if self.use_fp16:
+                release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+                # update fp16 partition updated by the current rank
+                for group_id in range(len(self._fp16_param_groups)):
+                    if self.param_group_has_params[group_id]:
+                        fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                            rank=self._zero_local_rank, group_id=group_id
+                        )
+                        fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                        fp16_param.data.copy_(fp32_param)
 
         # TODO: support broadcast overlap
         self.broadcast_params(overlap=False)
@@ -614,8 +644,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # The following operations are performed only on the rank to which parameters are assigned.
                 if rank not in self.param_group_no_params_ranks[group_id]:
                     fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
                     g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
                     handle = dist.broadcast(
                         fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
@@ -667,48 +695,52 @@ class HybridZeroOptimizer(BaseOptimizer):
     def state_dict(self):
         states = {}
-        grad_scaler = self.grad_scaler.state_dict()
-        states["grad_scaler"] = grad_scaler
         optim_states = self.optim.state_dict()
         states["base_optim_states"] = optim_states
 
-        flat_fp32_weights = {}
-        for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                assert param.grad is None
-                flat_fp32_weights[group_id] = param
-        states["flat_fp32_weights"] = flat_fp32_weights
+        if self.use_fp16:
+            grad_scaler = self.grad_scaler.state_dict()
+            states["grad_scaler"] = grad_scaler
+
+            flat_fp32_weights = {}
+            for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    assert param.grad is None
+                    flat_fp32_weights[group_id] = param
+            states["flat_fp32_weights"] = flat_fp32_weights
 
         states["zero_devide_optim_plan"] = self.params_per_rank_id_dict
 
         return states
 
     def load_state_dict(self, states):
         # TODO: Need to take into account the change in the number of DP.
-        assert "grad_scaler" in states, "Not found grad_scaler state!"
-        grad_scaler = states["grad_scaler"]
-        self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
         self.optim.load_state_dict(optim_states)
 
-        # load fp32 model weight.
-        flat_fp32_weights = states["flat_fp32_weights"]
-        assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
-        for group_id, param in flat_fp32_weights.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                assert (
-                    self_param.shape == param.shape
-                ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
-                self_param.data.copy_(param.data)
+        if self.use_fp16:
+            assert "grad_scaler" in states, "Not found grad_scaler state!"
+            grad_scaler = states["grad_scaler"]
+            self.grad_scaler.load_state_dict(grad_scaler)
 
-        # Load the fp16 model weights.
-        for group_id in range(len(self._fp16_param_groups)):
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                    rank=self._zero_local_rank, group_id=group_id
-                )
-                fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                fp16_param.data.copy_(fp32_param)
+            # load fp32 model weight.
+            flat_fp32_weights = states["flat_fp32_weights"]
+            assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
+            for group_id, param in flat_fp32_weights.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    assert (
+                        self_param.shape == param.shape
+                    ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
+                    self_param.data.copy_(param.data)
+
+            # Load the fp16 model weights.
+            for group_id in range(len(self._fp16_param_groups)):
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                        rank=self._zero_local_rank, group_id=group_id
+                    )
+                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    fp16_param.data.copy_(fp32_param)
 
         if "zero_devide_optim_plan" in states:
             self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
diff --git a/train.py b/train.py
index e2bd096..1067395 100644
--- a/train.py
+++ b/train.py
@@ -282,7 +282,10 @@ def initialize_optimizer(model: nn.Module):
     )
 
     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        use_fp16=gpc.config.model.dtype is not torch.float32,
     )
 
     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
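
Usage sketch (illustrative, not part of the patch): with this change, `use_fp16` follows the configured model dtype, so pure fp32 training skips the DynamicGradScaler and the flat fp32 master copy and lets the wrapped optimizer update the per-rank parameters directly. The fragment below only assumes that `gpc.config.model.dtype` holds a torch dtype, as the comparison in train.py implies; the remaining field names are placeholders.

# Hypothetical config fragment -- only the `dtype` field is read by the change in train.py.
import torch

model = dict(
    dtype=torch.float32,  # -> use_fp16=False: no grad scaler, no flat fp32 master weights
    # dtype=torch.bfloat16  # -> use_fp16=True: original mixed-precision ZeRO path
)

# train.py (after this patch) derives the flag as:
#     use_fp16 = gpc.config.model.dtype is not torch.float32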