Feat(QA): Check loss when swapping micro_num and micro_bsz && Check grad norm (#510)

* unitest_only_forward

* memory_test

* doc fix

* doc fix

* check loss

* check grad norm

* check grad norm
jiaxingli 2023-11-24 12:05:14 +08:00 committed by GitHub
parent 0d3811c029
commit b59641715a
1 changed file with 414 additions and 0 deletions

@@ -0,0 +1,414 @@
import multiprocessing as mp
import os
import random
import time
import numpy as np
import pytest
import torch
import torch.distributed as dist
from tqdm import tqdm
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.core.scheduler import SchedulerMetricHook
from internlm.initialize.launch import args_sanity_check
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
initialize_model,
initialize_optimizer,
)
from internlm.utils.evaluation import switch_evaluation_no_pipeline_scheduler
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
TOTAL_STEPS = 300
config = Config(
dict(
parallel=dict(
zero1=dict(size=-1, fsdp=False),
pipeline=dict(size=1, interleaved_overlap=False),
sequence_parallel=False,
tensor=1,
),
data=dict(
seq_len=2048,
micro_num=4,
micro_bsz=2,
pack_sample_into_one=False,
min_length=50,
total_steps=TOTAL_STEPS,
valid_micro_num=4,
valid_every=300,
rampup_batch_size=None,
diag_outlier_ratio=1.1,
train_folder=os.path.join(
os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/train"
),
valid_folder=os.path.join(
os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/val"
),
),
model=dict(
checkpoint=False,
num_attention_heads=16,
embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=4096,
num_layers=16,
mlp_ratio=8 / 3,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1,
),
model_type="INTERNLM",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
grad_scaler=dict(
fp16=dict(
initial_scale=2**16,
min_scale=1,
growth_interval=1000,
),
growth_factor=2,
backoff_factor=0.5,
max_scale=2**24,
hysteresis=2,
),
adam=dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
),
hybrid_zero_optimizer=dict(
overlap_sync_grad=True,
overlap_sync_param=True,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
),
beta2_scheduler=dict(
init_beta2=0.95,
c=0,
cur_iter=-1,
),
lr_scheduler=dict(
total_steps=TOTAL_STEPS,
init_steps=0,
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
),
ckpt=dict(
enable_save_ckpt=False,
auto_resume=False,
),
loss=dict(
label_smoothing=0,
),
)
)
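
# Set torch distributed environment variables for this rank and launch InternLM with a fixed seed.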
def build_environment(rank, world_size, config):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "33333"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
args_sanity_check()
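
# Seed the Python, NumPy and PyTorch RNGs; optionally force deterministic cuDNN kernels.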
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
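
# Fetch the next training batch, restarting the iterator once the data loader is exhausted.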
def load_new_batch(train_dl, train_iter):
try:
batch = next(train_iter)
except StopIteration:
train_iter = iter(train_dl)
batch = next(train_iter)
return batch, train_iter
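
# Run forward-only evaluation over the validation data loaders and return the per-batch mean loss.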
def evaluate_on_val_dls(
trainer,
val_dls,
):
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
for _, val_dl in val_dls.items():
if len(val_dl) == 0 and verbose:
continue
val_metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
)
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
val_loss = 0
val_idx = -1
for val_idx, batch in tqdm(
enumerate(val_dl),
desc="Val.",
total=len(val_dl),
position=1,
disable=not verbose,
leave=False,
):
with torch.inference_mode():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
dist.barrier()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
trainer.train()
torch.cuda.empty_cache()
dist.barrier()
return val_loss
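
# Average a list of values after dropping the first and last 5% of entries.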
def compute_trimmed_mean(value_list):
trim = int(0.05 * len(value_list))
trimmed_list = value_list[trim:-trim]
trimmed_mean = sum(trimmed_list) / len(trimmed_list)
return trimmed_mean
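
# Compare the trimmed mean of the last 100 grad norms against a stored baseline within tolerance.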
def check_grad_norm(grad_norm_list):
standard_grad_norm_list = torch.load(os.path.join(
os.environ["share_path"], "quailty_assurance/small_300step_norm_grad/grad_norm_list.pt"
))
standard_grad_norm_list = standard_grad_norm_list[-100:]
grad_norm_list = grad_norm_list[-100:]
standard_grad_norm_list.sort()
grad_norm_list.sort()
trimmed_mean1 = compute_trimmed_mean(standard_grad_norm_list)
trimmed_mean2 = compute_trimmed_mean(grad_norm_list)
tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
logger.info(f"norm_mean: {tensor_trimmed_mean1}, {tensor_trimmed_mean2}")
assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-1, atol=3e-1)
    logger.info("grad norm check passed")
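
# Compare the trimmed means of the last 100 train losses (and the validation losses) between the two rounds.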
def check_meanLoss_val(all_loss, all_val):
loss_values1 = all_loss[0][-100:]
loss_values2 = all_loss[1][-100:]
loss_values1.sort()
loss_values2.sort()
trimmed_mean1 = compute_trimmed_mean(loss_values1)
trimmed_mean2 = compute_trimmed_mean(loss_values2)
tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
logger.info(f"all_val: {all_val}")
assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
    logger.info("loss check passed")
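
# Per-rank training routine: train for total_steps with the given micro_num/micro_bsz,
# collect train losses, validation losses and grad norms, and check grad norms on the log rank.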
def exam_loss(args):
# init
rank, world_size, micro_num, micro_bsz = args
config.data.micro_num = micro_num
config.data.micro_bsz = micro_bsz
build_environment(rank, world_size, config)
total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every
# set seed
seed_all(1024)
# initialize model
model = initialize_model()
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)
# initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=0)
val_dls = get_validation_data_loader()
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
dataset_types=dataset_types,
)
# initialize trainer
scheduler_hooks = [
SchedulerMetricHook(
metric=metric,
skip=(
gpc.is_using_pp()
and hasattr(gpc.config.model, "num_chunks")
and gpc.config.model.num_chunks > 1
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
),
),
]
trainer, train_dl, _, _ = internlm.initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=scheduler_hooks,
)
trainer.train()
    # transfer the train data loader into train data iterator
    train_iter = iter(train_dl)
loss_list = []
val_list = []
grad_norm_list = []
for batch_count in range(total_steps):
start_time = time.time()
# load batch data
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
# zero the grads of parameters
trainer.zero_grad()
# process data
if batch[0].get("type_ids", None) is not None:
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
loss_list.append(loss.item())
num_tokens_in_batch = batch[1].nelement()
tgs_origin = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
/ (time.time() - start_time),
2,
)
if rank == 0:
logger.info(f"batch_count: {batch_count}, tgs: {tgs_origin}, loss: {loss}")
# update parameters
trainer_result = trainer.step()
assert trainer_result is not None
_, grad_norm_groups = trainer_result
if gpc.is_rank_for_log():
logger.info(f"train_grad_norm_groups: {grad_norm_groups['0_default']}")
grad_norm_list.append(grad_norm_groups['0_default'])
# evaluate on validation data loaders
if valid_every > 0 and batch_count > 0 and (batch_count + 1) % valid_every == 0:
val_result = evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
)
if val_result != 0:
val_list.append(val_result)
torch.cuda.empty_cache()
dist.barrier()
if gpc.is_rank_for_log():
check_grad_norm(grad_norm_list)
return rank, loss_list, val_list
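
# Entry point: run two 8-rank training rounds with micro_num and micro_bsz swapped in the
# second round, then check that the resulting losses match within tolerance.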
def test_loss():
ctx = mp.get_context("spawn")
all_loss = []
all_val = []
micro_num = 4
micro_bsz = 2
for train_round in range(2):
if train_round == 1:
micro_num, micro_bsz = micro_bsz, micro_num
with ctx.Pool(processes=8) as pool:
results = pool.map(
exam_loss,
[[rank, 8, micro_num, micro_bsz] for rank in range(8)],
)
all_loss.append(results[0][1])
all_val.append(results[0][2])
pool.close()
pool.join()
check_meanLoss_val(all_loss, all_val)
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_diff_num_bsz_loss.py"])