mirror of https://github.com/InternLM/InternLM
Fix evaluation bug in pipeline parallelism (pp)
parent
54e561665e
commit
144731c35c
|
@ -283,7 +283,7 @@ def args_sanity_check():
|
||||||
if gpc.config.parallel["tensor"].get("mode", None) is None:
|
if gpc.config.parallel["tensor"].get("mode", None) is None:
|
||||||
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
|
gpc.config.parallel["tensor"]["mode"] = "origin_tp"
|
||||||
|
|
||||||
if gpc.config.parallel["tensor"].get("mode", None) is 'fstp':
|
if gpc.config.parallel["tensor"].get("mode", None) == 'fstp':
|
||||||
assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
|
assert gpc.config.parallel.sequence_parallel is True, "when the tp_mode is fstp, the sequence_parallel should be True."
|
||||||
|
|
||||||
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
|
# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
|
||||||
|
|
|
@ -106,6 +106,12 @@ def evaluate_on_val_dls(
|
||||||
total_val_bsz = len(batch[1])
|
total_val_bsz = len(batch[1])
|
||||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||||
|
if gpc.config.parallel['tensor']['mode'] == 'fstp':
|
||||||
|
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||||
|
tensor_shape = torch.Size(
|
||||||
|
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1] // sequence_world_size, gpc.config.HIDDEN_SIZE]
|
||||||
|
)
|
||||||
|
else:
|
||||||
tensor_shape = torch.Size(
|
tensor_shape = torch.Size(
|
||||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue