from contextlib import contextmanager

import torch
import torch.distributed as dist
from tqdm import tqdm

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.model.metrics import AccPerplex


@contextmanager
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
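    """Swap in evaluation-time gradient-accumulation settings and metric hooks on the
    non-pipeline scheduler, restoring the previous values when the context exits."""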
    if not gpc.is_using_pp():
        prev_data_process_func = trainer.schedule.data_process_func
        prev_grad_accum_size = trainer.schedule._grad_accum_size
        prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
        prev_metric_hooks = trainer.schedule._hooks
        try:
            trainer.schedule.data_process_func = None
            trainer.schedule._grad_accum_size = grad_accum_size
            trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
            trainer.schedule._hooks = metric_hook_list
            yield
        finally:
            trainer.schedule.data_process_func = prev_data_process_func
            trainer.schedule._grad_accum_size = prev_grad_accum_size
            trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
            trainer.schedule._hooks = prev_metric_hooks
    else:
        # A @contextmanager generator must yield exactly once; fall through
        # unchanged when pipeline parallelism is active.
        yield


@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
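    """Swap in evaluation-time micro-batch count, tensor shape, and metric hooks on the
    pipeline scheduler, restoring the previous values when the context exits."""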
    if gpc.is_using_pp():
        prev_data_process_func = trainer.schedule.data_process_func
        prev_num_microbatches = trainer.schedule.num_microbatches
        prev_tensor_shape = trainer.schedule.tensor_shape
        prev_metric_hooks = trainer.schedule._hooks
        try:
            trainer.schedule.data_process_func = None
            trainer.schedule.num_microbatches = num_microbatches
            trainer.schedule.tensor_shape = tensor_shape
            trainer.schedule._hooks = metric_hook_list
            yield
        finally:
            trainer.schedule.data_process_func = prev_data_process_func
            trainer.schedule.num_microbatches = prev_num_microbatches
            trainer.schedule.tensor_shape = prev_tensor_shape
            trainer.schedule._hooks = prev_metric_hooks
    else:
        # A @contextmanager generator must yield exactly once; fall through
        # unchanged when pipeline parallelism is not active.
        yield


def evaluate_on_val_dls(
    trainer,
    val_dls,
    writer,
    logger,
    step_count,
    update_panel: bool = False,
):
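    """Run a forward-only pass over every validation dataloader, track loss on the
    logging rank, collect accuracy and perplexity through ``AccPerplex``, and report
    the results via ``writer`` and ``logger``."""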
    torch.cuda.empty_cache()
    trainer.eval()
    verbose = gpc.is_rank_for_log()
    data_cfg = gpc.config.data

    for val_name, val_dl in val_dls.items():
        if len(val_dl) == 0:
            # Skip empty dataloaders on every rank, not only the logging rank,
            # so all ranks stay in step for the assert and barrier below.
            if verbose:
                logger.info(f"Validation dataset: {val_name} is empty")
            continue

        val_metric = AccPerplex(
            device=torch.cuda.current_device(),
            tp_pg=gpc.get_group(ParallelMode.TENSOR),
            dp_pg=gpc.get_group(ParallelMode.DATA),
        )
        val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
        val_loss = 0
        val_idx = -1
        for val_idx, batch in tqdm(
            enumerate(val_dl),
            desc="Val.",
            total=len(val_dl),
            position=1,
            disable=not verbose,
            leave=False,
        ):
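            # Each `batch` is a (data, label) pair; the label tensor's length gives this step's batch size.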
            with torch.inference_mode():
                if gpc.is_using_pp():
                    total_val_bsz = len(batch[1])
                    assert total_val_bsz % data_cfg.micro_bsz == 0
                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
                    # Shape of one micro-batch's activations (micro_bsz, seq_len, hidden_size),
                    # which the pipeline scheduler needs for inter-stage communication.
                    tensor_shape = torch.Size(
                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
                    )

                    with switch_evaluation_pipeline_scheduler(
                        trainer=trainer,
                        num_microbatches=num_microbatches,
                        tensor_shape=tensor_shape,
                        metric_hook_list=[val_sche_metric_hook],
                    ):
                        _, _, loss = trainer.execute_schedule(
                            batch, forward_only=True, return_loss=True, return_output_label=False
                        )
                else:
                    total_val_bsz = len(batch[1])
                    assert total_val_bsz % data_cfg.micro_bsz == 0
                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                    grad_accum_batch_size = data_cfg.micro_bsz
                    with switch_evaluation_no_pipeline_scheduler(
                        trainer=trainer,
                        grad_accum_size=grad_accum_size,
                        grad_accum_batch_size=grad_accum_batch_size,
                        metric_hook_list=[val_sche_metric_hook],
                    ):
                        _, _, loss = trainer.execute_schedule(
                            batch, forward_only=True, return_loss=True, return_output_label=False
                        )
                if verbose:
                    val_loss += loss.item()

        assert val_idx != -1
        dist.barrier()

        val_res = val_metric.get_metric()
        if verbose and len(val_dl) != 0:
            val_loss = val_loss / (val_idx + 1 + 1e-6)
            infos = {
                "step": step_count,
                f"val/{val_name}_loss": val_loss,
                f"val/{val_name}_acc": val_res["acc"],
                f"val/{val_name}_plex": val_res["perplexity"],
            }

            for key, value in infos.items():
                writer.add_scalar(key=key, value=value, step=step_count)

            if update_panel:
                logger.info(
                    f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
                    extra={
                        "step": step_count,
                        "val_loss": val_loss,
                        "val_acc": val_res["acc"],
                        "val_perplexity": val_res["perplexity"],
                    },
                )
            else:
                logger.info(
                    f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
                )

    trainer.train()
    torch.cuda.empty_cache()
    dist.barrier()


@contextmanager
def switch_sequence_parallel_mode():
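    """Temporarily disable sequence parallelism (e.g. while running evaluation) and
    restore the previous setting on exit."""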
    prev_mode = gpc.config.model.sequence_parallel
    try:
        gpc.config.model.sequence_parallel = False
        yield
    finally:
        gpc.config.model.sequence_parallel = prev_mode