check grad norm

pull/510/head
lijiaxing 2023-11-23 14:16:48 +08:00
parent 61346c24f6
commit ed1d9c3b7c
1 changed file with 64 additions and 26 deletions

@@ -28,6 +28,7 @@ from internlm.utils.logger import get_logger
logger = get_logger(__file__)
TOTAL_STEPS = 300
config = Config(
    dict(
        parallel=dict(
@@ -42,7 +43,7 @@ config = Config(
            micro_bsz=2,
            pack_sample_into_one=False,
            min_length=50,
            total_steps=300,
            total_steps=TOTAL_STEPS,
            valid_micro_num=4,
            valid_every=300,
            rampup_batch_size=None,
@@ -105,7 +106,7 @@ config = Config(
            cur_iter=-1,
        ),
        lr_scheduler=dict(
            total_steps=100,
            total_steps=TOTAL_STEPS,
            init_steps=0,
            warmup_ratio=0.01,
            eta_min=1e-5,
@@ -217,6 +218,53 @@ def evaluate_on_val_dls(
    return val_loss
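
# mean after dropping the lowest and highest 5% of values, to damp outliers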
def compute_trimmed_mean(value_list):
    trim = int(0.05 * len(value_list))
    # guard: with fewer than 20 values trim is 0 and value_list[0:-0] would be empty
    trimmed_list = value_list[trim:-trim] if trim > 0 else value_list
    trimmed_mean = sum(trimmed_list) / len(trimmed_list)
    return trimmed_mean
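
# compare this run's grad-norm trace against a reference trace stored on shared storage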
def check_grad_norm(grad_norm_list):
    standard_grad_norm_list = torch.load(os.path.join(
        os.environ["share_path"], "quailty_assurance/small_300step_norm_grad/grad_norm_list.pt"
    ))
    standard_grad_norm_list = standard_grad_norm_list[-100:]
    grad_norm_list = grad_norm_list[-100:]
    standard_grad_norm_list.sort()
    grad_norm_list.sort()

    trimmed_mean1 = compute_trimmed_mean(standard_grad_norm_list)
    trimmed_mean2 = compute_trimmed_mean(grad_norm_list)
    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)

    logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
    logger.info("grad norm check passed")
def check_meanLoss_val(all_loss, all_val):
    loss_values1 = all_loss[0][-100:]
    loss_values2 = all_loss[1][-100:]
    loss_values1.sort()
    loss_values2.sort()

    trimmed_mean1 = compute_trimmed_mean(loss_values1)
    trimmed_mean2 = compute_trimmed_mean(loss_values2)
    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)

    logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
    logger.info(f"all_val: {all_val}")

    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
    logger.info("loss check passed")

def exam_loss(args):
    # init
    rank, world_size, micro_num, micro_bsz = args
@@ -234,7 +282,7 @@ def exam_loss(args)
    model = initialize_model()

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)

    # initialize the train and validation data loader
    train_dl, dataset_types = get_train_data_loader(num_worker=0)
@@ -279,6 +327,7 @@ def exam_loss(args)
    # transfer the train data loader into train data iterator
    loss_list = []
    val_list = []
    grad_norm_list = []
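    # grad_norm_list feeds check_grad_norm after the training loop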

    for batch_count in range(total_steps):
        start_time = time.time()
        # load batch data
@@ -312,7 +361,14 @@ def exam_loss(args)
            logger.info(f"batch_count: {batch_count}, tgs: {tgs_origin}, loss: {loss}")

        # update parameters
        trainer.step()
        trainer_result = trainer.step()
        assert trainer_result is not None
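
        # the step result unpacks into (_, grad_norm_groups): global grad norm per param group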
        _, grad_norm_groups = trainer_result
        if gpc.is_rank_for_log():
            logger.info(f"train_grad_norm_groups: {grad_norm_groups['0_default']}")
            grad_norm_list.append(grad_norm_groups["0_default"])

        # evaluate on validation data loaders
        if valid_every > 0 and batch_count > 0 and (batch_count + 1) % valid_every == 0:
@@ -325,31 +381,13 @@ def exam_loss(args)
    torch.cuda.empty_cache()
    dist.barrier()

    if gpc.is_rank_for_log():
        check_grad_norm(grad_norm_list)

    return rank, loss_list, val_list

# removed by this commit (superseded by the check_meanLoss_val defined above):
def check_meanLoss_val(all_loss, all_val):
    loss_values1 = all_loss[0][-100:]
    loss_values2 = all_loss[1][-100:]
    loss_values1.sort()
    loss_values2.sort()
    trim = int(0.05 * len(loss_values1))
    trimmed_loss_values1 = loss_values1[trim:-trim]
    trimmed_loss_values2 = loss_values2[trim:-trim]
    trimmed_mean1 = sum(trimmed_loss_values1) / len(trimmed_loss_values1)
    trimmed_mean2 = sum(trimmed_loss_values2) / len(trimmed_loss_values2)
    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
    logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
    logger.info(f"all_val: {all_val}")
    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)

def test_loss():
    ctx = mp.get_context("spawn")

    all_loss = []