check loss

pull/510/head
lijiaxing 2023-11-22 10:29:32 +08:00
parent 35aa093afe
commit 61346c24f6
1 changed file with 376 additions and 0 deletions


@@ -0,0 +1,376 @@
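# Loss-consistency regression test for InternLM: train the same 16-layer model twice
# on 8 ranks with micro_num/micro_bsz swapped between rounds, then check that the
# trimmed-mean training loss and the recorded validation losses of the two rounds agree.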
import multiprocessing as mp
import os
import random
import time
import numpy as np
import pytest
import torch
import torch.distributed as dist
from tqdm import tqdm
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.core.scheduler import SchedulerMetricHook
from internlm.initialize.launch import args_sanity_check
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
initialize_model,
initialize_optimizer,
)
from internlm.utils.evaluation import switch_evaluation_no_pipeline_scheduler
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
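# Test configuration: pure data parallelism (zero1 over all ranks, pipeline/tensor size 1),
# 300 training steps, and fixed train/val folders taken from $share_path.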
config = Config(
dict(
parallel=dict(
zero1=dict(size=-1, fsdp=False),
pipeline=dict(size=1, interleaved_overlap=False),
sequence_parallel=False,
tensor=1,
),
data=dict(
seq_len=2048,
micro_num=4,
micro_bsz=2,
pack_sample_into_one=False,
min_length=50,
total_steps=300,
valid_micro_num=4,
valid_every=300,
rampup_batch_size=None,
diag_outlier_ratio=1.1,
train_folder=os.path.join(
os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/train"
),
valid_folder=os.path.join(
os.environ["share_path"], "quailty_assurance/0623_scratch_tokenized_filtered/val"
),
),
model=dict(
checkpoint=False,
num_attention_heads=16,
embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=4096,
num_layers=16,
mlp_ratio=8 / 3,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1,
),
model_type="INTERNLM",
alert_address=None,
monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
grad_scaler=dict(
fp16=dict(
initial_scale=2**16,
min_scale=1,
growth_interval=1000,
),
growth_factor=2,
backoff_factor=0.5,
max_scale=2**24,
hysteresis=2,
),
adam=dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
),
hybrid_zero_optimizer=dict(
overlap_sync_grad=True,
overlap_sync_param=True,
reduce_bucket_size=512 * 1024 * 1024,
clip_grad_norm=1.0,
),
beta2_scheduler=dict(
init_beta2=0.95,
c=0,
cur_iter=-1,
),
lr_scheduler=dict(
total_steps=100,
init_steps=0,
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
),
ckpt=dict(
enable_save_ckpt=False,
auto_resume=False,
),
loss=dict(
label_smoothing=0,
),
)
)
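# Set up the torch.distributed env vars for a single-node 8-GPU run and hand the config
# to internlm.launch_from_torch, which initializes the global parallel context.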
def build_environment(rank, world_size, config):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "33333"
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
args_sanity_check()
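# Seed python, numpy and torch identically on every rank so both training rounds start
# from the same weights; cuda_deterministic trades throughput for reproducible kernels.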
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
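# Fetch the next batch, restarting the iterator when the dataloader is exhausted.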
def load_new_batch(train_dl, train_iter):
try:
batch = next(train_iter)
except StopIteration:
train_iter = iter(train_dl)
batch = next(train_iter)
return batch, train_iter
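# Forward-only evaluation over every validation dataloader; the per-batch mean loss is
# accumulated (and therefore only meaningful) on the logging rank.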
def evaluate_on_val_dls(
trainer,
val_dls,
):
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
for _, val_dl in val_dls.items():
        if len(val_dl) == 0:
            continue
val_metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
)
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
val_loss = 0
val_idx = -1
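        # iterate the whole validation set without gradients, sizing the no-pipeline
        # scheduler so each full validation batch is consumed in one pass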
for val_idx, batch in tqdm(
enumerate(val_dl),
desc="Val.",
total=len(val_dl),
position=1,
disable=not verbose,
leave=False,
):
with torch.inference_mode():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
dist.barrier()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
trainer.train()
torch.cuda.empty_cache()
dist.barrier()
return val_loss
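# One complete training run, executed inside a spawned worker process. Returns the rank
# together with the per-step loss history and validation losses so that rank 0's results
# from the two rounds can be compared.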
def exam_loss(args):
# init
rank, world_size, micro_num, micro_bsz = args
config.data.micro_num = micro_num
config.data.micro_bsz = micro_bsz
build_environment(rank, world_size, config)
total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every
# set seed
seed_all(1024)
# initialize model
model = initialize_model()
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)
# initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=0)
val_dls = get_validation_data_loader()
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
dataset_types=dataset_types,
)
# initialize trainer
scheduler_hooks = [
SchedulerMetricHook(
metric=metric,
skip=(
gpc.is_using_pp()
and hasattr(gpc.config.model, "num_chunks")
and gpc.config.model.num_chunks > 1
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
),
),
]
trainer, train_dl, _, _ = internlm.initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dl,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=scheduler_hooks,
)
trainer.train()
    # transfer the train data loader into a train data iterator
    train_iter = iter(train_dl)
loss_list = []
val_list = []
for batch_count in range(total_steps):
start_time = time.time()
# load batch data
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)
# zero the grads of parameters
trainer.zero_grad()
# process data
if batch[0].get("type_ids", None) is not None:
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
loss_list.append(loss.item())
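        # tgs: tokens processed per GPU per second for this step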
num_tokens_in_batch = batch[1].nelement()
tgs_origin = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
/ (time.time() - start_time),
2,
)
if rank == 0:
logger.info(f"batch_count: {batch_count}, tgs: {tgs_origin}, loss: {loss}")
# update parameters
trainer.step()
# evaluate on validation data loaders
if valid_every > 0 and batch_count > 0 and (batch_count + 1) % valid_every == 0:
val_result = evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
)
if val_result != 0:
val_list.append(val_result)
torch.cuda.empty_cache()
dist.barrier()
return rank, loss_list, val_list
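# Compare the two rounds: take the last 100 training losses from each, drop the lowest and
# highest 5% as outliers, and require the trimmed means (and the recorded validation losses)
# to match within a 3% absolute/relative tolerance.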
def check_meanLoss_val(all_loss, all_val):
loss_values1 = all_loss[0][-100:]
loss_values2 = all_loss[1][-100:]
loss_values1.sort()
loss_values2.sort()
trim = int(0.05 * len(loss_values1))
trimmed_loss_values1 = loss_values1[trim:-trim]
trimmed_loss_values2 = loss_values2[trim:-trim]
trimmed_mean1 = sum(trimmed_loss_values1) / len(trimmed_loss_values1)
trimmed_mean2 = sum(trimmed_loss_values2) / len(trimmed_loss_values2)
tensor_trimmed_mean1 = torch.tensor(trimmed_mean1)
tensor_trimmed_mean2 = torch.tensor(trimmed_mean2)
logger.info(f"avg_value: {trimmed_mean1}, {trimmed_mean2}")
logger.info(f"all_val: {all_val}")
assert torch.allclose(tensor_trimmed_mean1, tensor_trimmed_mean2, rtol=3e-2, atol=3e-2)
assert torch.allclose(torch.tensor(all_val[0]), torch.tensor(all_val[1]), rtol=3e-2, atol=3e-2)
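# Entry point: round 0 trains with micro_num=4, micro_bsz=2 and round 1 swaps them, keeping
# the global batch (micro_num * micro_bsz) constant while changing the accumulation shape.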
def test_loss():
ctx = mp.get_context("spawn")
all_loss = []
all_val = []
micro_num = 4
micro_bsz = 2
for train_round in range(2):
if train_round == 1:
micro_num, micro_bsz = micro_bsz, micro_num
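        # one worker process per rank; each returns (rank, loss_list, val_list) and only
        # rank 0's histories are kept for the comparison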
with ctx.Pool(processes=8) as pool:
results = pool.map(
exam_loss,
[[rank, 8, micro_num, micro_bsz] for rank in range(8)],
)
all_loss.append(results[0][1])
all_val.append(results[0][2])
pool.close()
pool.join()
check_meanLoss_val(all_loss, all_val)
if __name__ == "__main__":
pytest.main(["-s", "-q", "test_diff_num_bsz_loss.py"])