add valid_pack_mode

pull/541/head
877825076@qq.com 2023-12-14 14:40:12 +08:00
parent 136aa7c5a5
commit cd91e92bd7
1 changed files with 10 additions and 4 deletions

View File

@ -103,10 +103,16 @@ def evaluate_on_val_dls(
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
)
if data_cfg.get("valid_pack_mode", None) is None or data_cfg.valid_pack_mode is False:
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
)
else:
num_microbatches = total_val_bsz
tensor_shape = torch.Size(
[1, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,