mirror of https://github.com/InternLM/InternLM
check grad norm
parent
61346c24f6
commit
ed1d9c3b7c
|
@ -28,6 +28,7 @@ from internlm.utils.logger import get_logger
|
|||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
TOTAL_STEPS = 300
|
||||
config = Config(
|
||||
dict(
|
||||
parallel=dict(
|
||||
|
@ -42,7 +43,7 @@ config = Config(
|
|||
micro_bsz=2,
|
||||
pack_sample_into_one=False,
|
||||
min_length=50,
|
||||
total_steps=300,
|
||||
total_steps=TOTAL_STEPS,
|
||||
valid_micro_num=4,
|
||||
valid_every=300,
|
||||
rampup_batch_size=None,
|
||||
|
@ -105,7 +106,7 @@ config = Config(
|
|||
cur_iter=-1,
|
||||
),
|
||||
lr_scheduler=dict(
|
||||
total_steps=100,
|
||||
total_steps=TOTAL_STEPS,
|
||||
init_steps=0,
|
||||
warmup_ratio=0.01,
|
||||
eta_min=1e-5,
|
||||
|
@ -217,6 +218,53 @@ def evaluate_on_val_dls(
|
|||
return val_loss
|
||||
|
||||
|
||||
def compute_trimmed_mean(value_list):
|
||||
trim = int(0.05 * len(value_list))
|
||||
trimmed_list = value_list[trim:-trim]
|
||||
trimmed_mean = sum(trimmed_list) / len(trimmed_list)
|
||||
return trimmed_mean
|
||||
|
||||
|
||||
def check_grad_norm(grad_norm_list):
|
||||
standard_grad_norm_list = torch.load(os.path.join(
|
||||
os.environ["share_path"], "quailty_assurance/small_300step_norm_grad/grad_norm_list.pt"
|
||||
))
|
||||
|
||||
standard_grad_norm_list = standard_grad_norm_list[-100:]
|
||||
grad_norm_list = grad_norm_list[-100:]
|
||||
standard_grad_norm_list.sort()
|
||||
grad_norm_list.sort()
|
||||
|
||||
trimmed_mean1 = compute_trimmed_mean(standard_grad_norm_list)
|
||||
trimmed_mean2 = compute_trimmed_mean(grad_norm_list)
|
||||
tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
|
||||
tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
|
||||
|
||||
logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
|
||||
assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
|
||||
logger.info(f"grad norm check passed")
|
||||
|
||||
|
||||
def check_meanLoss_val(all_loss, all_val):
|
||||
loss_values1 = all_loss[0][-100:]
|
||||
loss_values2 = all_loss[1][-100:]
|
||||
loss_values1.sort()
|
||||
loss_values2.sort()
|
||||
|
||||
trimmed_mean1 = compute_trimmed_mean(loss_values1)
|
||||
trimmed_mean2 = compute_trimmed_mean(loss_values2)
|
||||
tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
|
||||
tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
|
||||
|
||||
logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
|
||||
logger.info(f"all_val: {all_val}")
|
||||
|
||||
assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
|
||||
assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
|
||||
|
||||
logger.info(f"loss check passed")
|
||||
|
||||
|
||||
def exam_loss(args):
|
||||
# init
|
||||
rank, world_size, micro_num, micro_bsz = args
|
||||
|
@ -234,7 +282,7 @@ def exam_loss(args):
|
|||
model = initialize_model()
|
||||
|
||||
# initialize loss function
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
|
||||
|
||||
# initialize the train and validation data loader
|
||||
train_dl, dataset_types = get_train_data_loader(num_worker=0)
|
||||
|
@ -279,6 +327,7 @@ def exam_loss(args):
|
|||
# transfer the train data loader into train data iterator
|
||||
loss_list = []
|
||||
val_list = []
|
||||
grad_norm_list = []
|
||||
for batch_count in range(total_steps):
|
||||
start_time = time.time()
|
||||
# load batch data
|
||||
|
@ -312,7 +361,14 @@ def exam_loss(args):
|
|||
logger.info(f"batch_count: {batch_count}, tgs: {tgs_origin}, loss: {loss}")
|
||||
|
||||
# update parameters
|
||||
trainer.step()
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
|
||||
_, grad_norm_groups = trainer_result
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"train_grad_norm_groups: {grad_norm_groups['0_default']}")
|
||||
grad_norm_list.append(grad_norm_groups['0_default'])
|
||||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and batch_count > 0 and (batch_count + 1) % valid_every == 0:
|
||||
|
@ -326,30 +382,12 @@ def exam_loss(args):
|
|||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
check_grad_norm(grad_norm_list)
|
||||
|
||||
return rank, loss_list, val_list
|
||||
|
||||
|
||||
def check_meanLoss_val(all_loss, all_val):
|
||||
loss_values1 = all_loss[0][-100:]
|
||||
loss_values2 = all_loss[1][-100:]
|
||||
loss_values1.sort()
|
||||
loss_values2.sort()
|
||||
|
||||
trim = int(0.05 * len(loss_values1))
|
||||
trimmed_loss_values1 = loss_values1[trim:-trim]
|
||||
trimmed_loss_values2 = loss_values2[trim:-trim]
|
||||
trimmed_mean1 = sum(trimmed_loss_values1) / len(trimmed_loss_values1)
|
||||
trimmed_mean2 = sum(trimmed_loss_values2) / len(trimmed_loss_values2)
|
||||
tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
|
||||
tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
|
||||
|
||||
logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
|
||||
logger.info(f"all_val: {all_val}")
|
||||
|
||||
assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
|
||||
assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
|
||||
|
||||
|
||||
def test_loss():
|
||||
ctx = mp.get_context("spawn")
|
||||
all_loss = []
|
||||
|
|
Loading…
Reference in New Issue