fix evaluation bug in pp

pull/407/head
yingtongxiong 2023-10-09 20:04:27 +08:00
parent 54e561665e
commit 144731c35c
2 changed files with 10 additions and 4 deletions

View File

@ -283,7 +283,7 @@ def args_sanity_check():
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
if gpc.config.parallel["tensor"].get("mode", None) is 'fstp':
if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy

View File

@ -106,9 +106,15 @@ def evaluate_on_val_dls(
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
if gpc.config.parallel['tensor']['mode'] == 'fstp':
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
)
else:
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,