diff --git a/configs/20B_sft.py b/configs/20B_sft.py
index 1d093ef..bc63d34 100644
--- a/configs/20B_sft.py
+++ b/configs/20B_sft.py
@@ -1,4 +1,4 @@
-JOB_NAME = "13b_train"
+JOB_NAME = "20b_train"
 DO_ALERT = False
 
 SEQ_LEN = 4096
@@ -51,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index cc9524a..0ea6ee3 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -423,7 +423,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.process_group = process_group
         self.FSTP_blocks = []
         self.FSTP_outs = []
-        self.FSTP_wqkvs = []
         self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
@@ -465,9 +464,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                         if name == "out_proj":
                             self.FSTP_outs.append(child)
                             self.module_to_index[child] = idx
-                        if name == "Wqkv":
-                            self.FSTP_wqkvs.append(child)
-                            self.module_to_index[child] = idx
                         if isinstance(child, FSTPLinear):
                             self.module_to_index[child] = idx
                             self.block_module[idx][index] = child
@@ -488,7 +484,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 self.embedding.append(children)
 
     def _all_gather_block_weight(self, block_index: int):
-        block = self.index_to_block[block_index]
+        #block = self.index_to_block[block_index]
         fsdp_modules = self.index_to_fsdp_modules[block_index]
         # self.block_handles[block] = []
         for module in fsdp_modules:
@@ -552,12 +548,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 if module in self.FSTP_global_handle:
                     del self.FSTP_global_handle[module]
 
-        def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
-            block_index = self.module_to_index[module]
-            # start the all-gather for next block
-            if block_index - 1 >= 0:
-                self._all_gather_block_weight(block_index - 1)
-
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             # import pdb; pdb.set_trace()
             block_index = self.block_to_index[block]
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index c7c1007..d226827 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -39,6 +39,14 @@ from .utils import compute_norm
 inf = math.inf
 logger = get_logger(__file__)
 
+def print_memory(msg):
+
+    if gpc.get_global_rank() == 0:
+        print(msg, flush=True)
+        print("memory allocated: ", torch.cuda.memory_allocated() / 1024 / 1024 / 1024, flush=True)
+        print("max memory allocated: ", torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, flush=True)
+        print("===========================================")
+
 
 class HybridZeroOptimizer(BaseOptimizer):
     """
@@ -335,6 +343,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                     comm_handle.wait()
                     _param.grad += _grad
+                    self._fstp_handler.reduce_scatter_handlers[key] = None
 
             bucket.reset_by_rank(rank)
 
@@ -358,6 +367,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
                 _param.grad += _grad
+                self._fstp_handler.reduce_scatter_handlers[key] = None
 
         # reduce grad
         if self.skip_grad_reduce is False:
@@ -565,6 +575,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         # if not overlapping communication (no reduction hook is attached)
        # we need to manually reduce these gradients
+        print_memory("No 1")
         if not self._overlap_sync_grad:
             for group_id in range(len(self._fp16_param_groups)):
                 for param in self._fp16_param_groups[group_id]:
@@ -589,7 +600,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket.empty()
         self._bucket_in_progress = []
         self._param_store.clear_grads_of_previous_reduced_params()
-
+        print_memory("No 2")
         # compute norm for gradients in the last bucket
         total_norms = {}
         for group_id in range(self.num_param_groups):
@@ -611,10 +622,12 @@ class HybridZeroOptimizer(BaseOptimizer):
             scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
             dist.all_reduce(scaled_norm_tensor, group=pg)
             total_norms[group_name] = scaled_norm_tensor.item()
-
+        print_memory("No 3")
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
+
+        print_memory("No 4")
 
         return self._step(closure=closure, norms=total_norms)
 
@@ -661,7 +674,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, norms
-
+        print_memory("No 5")
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
         for group_id in range(self.num_param_groups):
@@ -702,7 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-
+        print_memory("No 6")
         # unscale and clip grads
         # get the global norm
         global_norm_groups = {}
@@ -725,9 +738,12 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
+            print_memory("No 7")
             self.optim.step()
+            print_memory("No 8")
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+            print_memory("No 9")
             # update fp16 partition updated by the current rank
             for group_id in range(len(self._fp16_param_groups)):
                 if self.param_group_has_params[group_id]:
@@ -736,17 +752,18 @@ class HybridZeroOptimizer(BaseOptimizer):
                     )
                     fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                     fp16_param.data.copy_(fp32_param)
-
+        print_memory("No 10")
         torch.cuda.synchronize()
         with torch.cuda.stream(self._comm_bcast_stream):
             self.broadcast_params()
-
+
         timer("step").stop()
 
         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
         for group_name, global_norm in global_norm_groups.items():
             global_norm_groups[group_name] = global_norm / loss_scale
+        print_memory("No 11")
         return True, global_norm_groups
 
     def broadcast_params(self):
diff --git a/train.py b/train.py
index 139bac1..0a84f59 100644
--- a/train.py
+++ b/train.py
@@ -296,6 +296,8 @@ def main(args):
 
             if batch_count % 2 == 0:
                 prof.step()
+
+            torch.cuda.reset_peak_memory_stats()
 
 
     ckpt_manager.wait_async_upload_finish()
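
The changes above are memory instrumentation: a rank-0 print_memory helper reports allocated and peak CUDA memory (in GiB) at numbered checkpoints inside the optimizer step, and train.py resets the peak counter once per iteration so the reported maximum reflects a single step. Below is a minimal, standalone sketch of that pattern, assuming a single process with CUDA available; SimpleNet, run_steps, and the torch.distributed rank check are illustrative stand-ins, not InternLM code, and the rank check replaces gpc.get_global_rank().

import torch


def print_memory(msg: str) -> None:
    """Print current and peak CUDA memory (GiB) from rank 0 only."""
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    if rank == 0:
        gib = 1024 ** 3
        print(msg, flush=True)
        print("memory allocated: ", torch.cuda.memory_allocated() / gib, flush=True)
        print("max memory allocated: ", torch.cuda.max_memory_allocated() / gib, flush=True)
        print("===========================================")


class SimpleNet(torch.nn.Module):
    """Tiny stand-in model, just enough to allocate gradients and optimizer state."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(1024, 1024)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def run_steps(num_steps: int = 4) -> None:
    device = torch.device("cuda")
    model = SimpleNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters())

    for step in range(num_steps):
        x = torch.randn(8, 1024, device=device)
        loss = model(x).sum()
        loss.backward()

        # Labeled checkpoints around the optimizer step, mirroring the "No 1" ... "No 11" probes.
        print_memory(f"step {step}: before optimizer.step()")
        optimizer.step()
        optimizer.zero_grad()
        print_memory(f"step {step}: after optimizer.step()")

        # Reset the peak counter so the next iteration's "max memory allocated"
        # is not dominated by earlier iterations.
        torch.cuda.reset_peak_memory_stats()


if __name__ == "__main__":
    if torch.cuda.is_available():
        run_steps()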