mirror of https://github.com/InternLM/InternLM
check grad norm
parent 61346c24f6
commit ed1d9c3b7c
@@ -28,6 +28,7 @@ from internlm.utils.logger import get_logger
 
 logger = get_logger(__file__)
 
+TOTAL_STEPS = 300
 config = Config(
     dict(
         parallel=dict(
@@ -42,7 +43,7 @@ config = Config(
             micro_bsz=2,
             pack_sample_into_one=False,
             min_length=50,
-            total_steps=300,
+            total_steps=TOTAL_STEPS,
             valid_micro_num=4,
             valid_every=300,
             rampup_batch_size=None,
@@ -105,7 +106,7 @@ config = Config(
             cur_iter=-1,
         ),
         lr_scheduler=dict(
-            total_steps=100,
+            total_steps=TOTAL_STEPS,
             init_steps=0,
             warmup_ratio=0.01,
             eta_min=1e-5,
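The module-level TOTAL_STEPS constant keeps data.total_steps and lr_scheduler.total_steps in sync; previously the scheduler was configured for 100 steps while training ran 300, so the cosine decay would have completed a third of the way into the run. A toy cosine schedule (not InternLM's actual scheduler class; the base lr of 1e-4 is made up, only eta_min comes from the config) showing the effect:

import math

def cosine_lr(step, total_steps, lr=1e-4, eta_min=1e-5):
    # plain cosine decay from lr down to eta_min over total_steps,
    # clamped at the floor once total_steps is exceeded
    t = min(step / total_steps, 1.0)
    return eta_min + 0.5 * (lr - eta_min) * (1 + math.cos(math.pi * t))

print(cosine_lr(150, total_steps=100))  # 1e-05: already fully decayed mid-run
print(cosine_lr(150, total_steps=300))  # 5.5e-05: mid-decay, as intended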
@@ -217,6 +218,53 @@ def evaluate_on_val_dls(
     return val_loss
 
 
+def compute_trimmed_mean(value_list):
+    trim = int(0.05 * len(value_list))
+    trimmed_list = value_list[trim:-trim]
+    trimmed_mean = sum(trimmed_list) / len(trimmed_list)
+    return trimmed_mean
+
+
+def check_grad_norm(grad_norm_list):
+    standard_grad_norm_list = torch.load(os.path.join(
+        os.environ["share_path"], "quailty_assurance/small_300step_norm_grad/grad_norm_list.pt"
+    ))
+
+    standard_grad_norm_list = standard_grad_norm_list[-100:]
+    grad_norm_list = grad_norm_list[-100:]
+    standard_grad_norm_list.sort()
+    grad_norm_list.sort()
+
+    trimmed_mean1 = compute_trimmed_mean(standard_grad_norm_list)
+    trimmed_mean2 = compute_trimmed_mean(grad_norm_list)
+    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
+    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
+
+    logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
+    logger.info("grad norm check passed")
+
+
+def check_meanLoss_val(all_loss, all_val):
+    loss_values1 = all_loss[0][-100:]
+    loss_values2 = all_loss[1][-100:]
+    loss_values1.sort()
+    loss_values2.sort()
+
+    trimmed_mean1 = compute_trimmed_mean(loss_values1)
+    trimmed_mean2 = compute_trimmed_mean(loss_values2)
+    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
+    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
+
+    logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
+    logger.info(f"all_val: {all_val}")
+
+    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
+    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
+
+    logger.info("loss check passed")
+
+
 def exam_loss(args):
     # init
     rank, world_size, micro_num, micro_bsz = args
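For context, the new helpers can be exercised in isolation. Because the lists are sorted before compute_trimmed_mean drops 5% off each end, occasional spikes in gradient norm land in the discarded tails instead of skewing the comparison. A minimal, self-contained sketch with made-up values (the helper body mirrors the diff; everything else is illustrative):

import torch

def compute_trimmed_mean(value_list):
    # assumes value_list is pre-sorted and long enough that trim > 0
    trim = int(0.05 * len(value_list))
    trimmed_list = value_list[trim:-trim]
    return sum(trimmed_list) / len(trimmed_list)

# 100 well-behaved norms plus two outliers; sorting pushes the
# outliers into the trimmed tails
norms = [1.0 + 0.01 * i for i in range(100)] + [50.0, 0.001]
norms.sort()
baseline = sorted(1.0 + 0.011 * i for i in range(102))

mean_run = torch.tensor(compute_trimmed_mean(norms))
mean_ref = torch.tensor(compute_trimmed_mean(baseline))
# same loose tolerance that check_grad_norm uses
assert torch.allclose(mean_run, mean_ref, rtol=3e-1, atol=3e-1)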
@@ -234,7 +282,7 @@ def exam_loss(args):
     model = initialize_model()
 
     # initialize loss function
-    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)
+    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
 
     # initialize the train and validation data loader
     train_dl, dataset_types = get_train_data_loader(num_worker=0)
@@ -279,6 +327,7 @@ def exam_loss(args):
     # transfer the train data loader into train data iterator
     loss_list = []
     val_list = []
+    grad_norm_list = []
    for batch_count in range(total_steps):
         start_time = time.time()
         # load batch data
@@ -312,7 +361,14 @@ def exam_loss(args):
             logger.info(f"batch_count: {batch_count}, tgs: {tgs_origin}, loss: {loss}")
 
         # update parameters
-        trainer.step()
+        trainer_result = trainer.step()
+        assert trainer_result is not None
+
+        _, grad_norm_groups = trainer_result
+
+        if gpc.is_rank_for_log():
+            logger.info(f"train_grad_norm_groups: {grad_norm_groups['0_default']}")
+            grad_norm_list.append(grad_norm_groups['0_default'])
 
         # evaluate on validation data loaders
         if valid_every > 0 and batch_count > 0 and (batch_count + 1) % valid_every == 0:
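The loop now relies on trainer.step() returning a tuple whose second element maps parameter-group names to their gradient norms, and only the logging rank appends to grad_norm_list. A minimal sketch of that consumption pattern; StubTrainer and its return value are hypothetical stand-ins, while the '0_default' group key comes from the diff itself:

class StubTrainer:
    # hypothetical stand-in mimicking a (success, grad_norm_groups) return
    def step(self):
        return True, {"0_default": 1.37}

grad_norm_list = []
trainer_result = StubTrainer().step()
assert trainer_result is not None

_, grad_norm_groups = trainer_result
grad_norm_list.append(grad_norm_groups["0_default"])
print(grad_norm_list)  # [1.37]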
@@ -326,30 +382,12 @@ def exam_loss(args):
         torch.cuda.empty_cache()
         dist.barrier()
 
+    if gpc.is_rank_for_log():
+        check_grad_norm(grad_norm_list)
+
     return rank, loss_list, val_list
 
 
-def check_meanLoss_val(all_loss, all_val):
-    loss_values1 = all_loss[0][-100:]
-    loss_values2 = all_loss[1][-100:]
-    loss_values1.sort()
-    loss_values2.sort()
-
-    trim = int(0.05 * len(loss_values1))
-    trimmed_loss_values1 = loss_values1[trim:-trim]
-    trimmed_loss_values2 = loss_values2[trim:-trim]
-    trimmed_mean1 = sum(trimmed_loss_values1) / len(trimmed_loss_values1)
-    trimmed_mean2 = sum(trimmed_loss_values2) / len(trimmed_loss_values2)
-    tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
-    tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
-
-    logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
-    logger.info(f"all_val: {all_val}")
-
-    assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
-    assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
-
-
 def test_loss():
     ctx = mp.get_context("spawn")
     all_loss = []
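Note the asymmetry in tolerances: losses are compared at rtol=atol=3e-2, while gradient norms get a much looser 3e-1, reflecting how much more they vary between runs. Since torch.allclose(a, b) passes when |a - b| <= atol + rtol * |b|, a baseline trimmed norm of about 1.5 tolerates a deviation of roughly 0.75. A quick illustration with made-up numbers:

import torch

baseline = torch.tensor(1.5)
# threshold = atol + rtol * |baseline| = 0.3 + 0.3 * 1.5 = 0.75
assert torch.allclose(torch.tensor(2.2), baseline, rtol=3e-1, atol=3e-1)      # |diff| = 0.70, passes
assert not torch.allclose(torch.tensor(2.3), baseline, rtol=3e-1, atol=3e-1)  # |diff| = 0.80, fails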