diff --git a/tests/test_training/test_diff_num_bsz_loss.py b/tests/test_training/test_diff_num_bsz_loss.py
index 6a25e53..04a6faa 100644
--- a/tests/test_training/test_diff_num_bsz_loss.py
+++ b/tests/test_training/test_diff_num_bsz_loss.py
@@ -28,6 +28,7 @@ from internlm.utils.logger import get_logger
 
 logger = get_logger(__file__)
 
+TOTAL_STEPS = 300
 config = Config(
     dict(
         parallel=dict(
@@ -42,7 +43,7 @@ config = Config(
             micro_bsz=2,
             pack_sample_into_one=False,
             min_length=50,
-            total_steps=300,
+            total_steps=TOTAL_STEPS,
             valid_micro_num=4,
             valid_every=300,
             rampup_batch_size=None,
@@ -105,7 +106,7 @@ config = Config(
             cur_iter=-1,
         ),
         lr_scheduler=dict(
-            total_steps=100,
+            total_steps=TOTAL_STEPS,
            init_steps=0,
            warmup_ratio=0.01,
            eta_min=1e-5,
@@ -217,6 +218,53 @@ def evaluate_on_val_dls(
     return val_loss
 
 
+def compute_trimmed_mean(value_list):
+    trim = int(0.05 * len(value_list))
+    trimmed_list = value_list[trim:-trim]
+    trimmed_mean = sum(trimmed_list) / len(trimmed_list)
+    return trimmed_mean
+
+
+def check_grad_norm(grad_norm_list):
+    standard_grad_norm_list = torch.load(os.path.join(
+        os.environ["share_path"], "quailty_assurance/small_300step_norm_grad/grad_norm_list.pt"
+    ))
+
+    standard_grad_norm_list = standard_grad_norm_list[-100:]
+    grad_norm_list = grad_norm_list[-100:]
+    standard_grad_norm_list.sort()
+    grad_norm_list.sort()
+
+    trimmed_mean1 = compute_trimmed_mean(standard_grad_norm_list)
+    trimmed_mean2 = compute_trimmed_mean(grad_norm_list)
+    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
+    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
+
+    logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
+    logger.info(f"grad norm check passed")
+
+
+def check_meanLoss_val(all_loss, all_val):
+    loss_values1 = all_loss[0][-100:]
+    loss_values2 = all_loss[1][-100:]
+    loss_values1.sort()
+    loss_values2.sort()
+
+    trimmed_mean1 = compute_trimmed_mean(loss_values1)
+    trimmed_mean2 = compute_trimmed_mean(loss_values2)
+    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
+    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
+
+    logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
+    logger.info(f"all_val: {all_val}")
+
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
+    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
+
+    logger.info(f"loss check passed")
+
+
 def exam_loss(args):
     # init
     rank, world_size, micro_num, micro_bsz = args
@@ -234,7 +282,7 @@ def exam_loss(args):
     model = initialize_model()
 
     # initialize loss function
-    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)
+    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
 
     # initialize the train and validation data loader
     train_dl, dataset_types = get_train_data_loader(num_worker=0)
@@ -279,6 +327,7 @@ def exam_loss(args):
     # transfer the train data loader into train data iterator
     loss_list = []
     val_list = []
+    grad_norm_list = []
     for batch_count in range(total_steps):
         start_time = time.time()
         # load batch data
@@ -312,7 +361,14 @@ def exam_loss(args):
             logger.info(f"batch_count: {batch_count}, tgs: {tgs_origin}, loss: {loss}")
 
         # update parameters
-        trainer.step()
+        trainer_result = trainer.step()
+        assert trainer_result is not None
+
+        _, grad_norm_groups = trainer_result
+
+        if gpc.is_rank_for_log():
+            logger.info(f"train_grad_norm_groups: {grad_norm_groups['0_default']}")
+            grad_norm_list.append(grad_norm_groups['0_default'])
 
         # evaluate on validation data loaders
         if valid_every > 0 and batch_count > 0 and (batch_count + 1) % valid_every == 0:
@@ -325,31 +381,13 @@ def exam_loss(args):
 
     torch.cuda.empty_cache()
     dist.barrier()
-
+
+    if gpc.is_rank_for_log():
+        check_grad_norm(grad_norm_list)
+
     return rank, loss_list, val_list
 
 
-def check_meanLoss_val(all_loss, all_val):
-    loss_values1 = all_loss[0][-100:]
-    loss_values2 = all_loss[1][-100:]
-    loss_values1.sort()
-    loss_values2.sort()
-
-    trim = int(0.05 * len(loss_values1))
-    trimmed_loss_values1 = loss_values1[trim:-trim]
-    trimmed_loss_values2 = loss_values2[trim:-trim]
-    trimmed_mean1 = sum(trimmed_loss_values1) / len(trimmed_loss_values1)
-    trimmed_mean2 = sum(trimmed_loss_values2) / len(trimmed_loss_values2)
-    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
-    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
-
-    logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
-    logger.info(f"all_val: {all_val}")
-
-    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
-    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
-
-
 def test_loss():
     ctx = mp.get_context("spawn")
     all_loss = []