mirror of https://github.com/InternLM/InternLM
add valid_pack_mode
parent
136aa7c5a5
commit
cd91e92bd7
|
@ -103,10 +103,16 @@ def evaluate_on_val_dls(
|
|||
if gpc.is_using_pp():
|
||||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
|
||||
)
|
||||
if data_cfg.get("valid_pack_mode", None) is None or data_cfg.valid_pack_mode is False:
|
||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
|
||||
)
|
||||
else:
|
||||
num_microbatches = total_val_bsz
|
||||
tensor_shape = torch.Size(
|
||||
[1, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
|
||||
)
|
||||
|
||||
with switch_evaluation_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
|
|
Loading…
Reference in New Issue