mirror of https://github.com/InternLM/InternLM
fix comments
parent fdbdfcff34
commit 9780c44917

@@ -115,7 +115,8 @@ def args_sanity_check():
             data.packed_length == data.seq_len * data.micro_bsz
         ), "'packed_length' must be equal to 'seq_len * micro_bsz'"
     else:
-        assert data.packed_length is not None, "'packed_length' must be given a value"
+        assert data.get("packed_length", None) is not None, "'packed_length' must be given a value"
+        assert data.packed_length % data.seq_len == 0, "'packed_length' must be divisible by 'seq_len'"
 
     if "micro_num" not in data:
         data._add_item("micro_num", 1)
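
The substance of this hunk is the swapped assert. With attribute access, a missing key raises on its own before the assert's message can surface, while data.get("packed_length", None) returns None and lets the assert report its readable error. A minimal runnable sketch of the difference, using a plain dict as a stand-in for the real config object (values are hypothetical):

    # Minimal sketch: why .get() matters when the key may be absent.
    # A plain dict stands in for the config object; data.packed_length in the
    # original behaves like the dict lookup used here.
    data = {"seq_len": 2048, "micro_bsz": 2}  # "packed_length" deliberately missing

    try:
        assert data["packed_length"] is not None, "'packed_length' must be given a value"
    except KeyError:
        print("old style: a bare KeyError; the assert message never surfaces")

    try:
        assert data.get("packed_length", None) is not None, "'packed_length' must be given a value"
    except AssertionError as err:
        print(f"new style: {err}")  # the intended, human-readable error
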
@@ -301,9 +301,7 @@ def get_validation_data_loader(
     else:
         # making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
         # otherwise too much data may be dropped
-        micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
-        batch_size = min(data_cfg.valid_micro_num * micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA))
-        batch_size = batch_size // micro_bsz * micro_bsz
+        batch_size = min(data_cfg.valid_micro_num, len(ds) // gpc.get_world_size(ParallelMode.DATA))
 
     if batch_size == 0 and gpc.is_rank_for_log():
         logger.info(f"skip validate {val_name}.")
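
Worth spelling out the sizing change: the removed lines scaled valid_micro_num by a packed-sequence micro batch size and then rounded the result down to a multiple of micro_bsz, whereas the new line counts whole validation micro-batches directly. A worked sketch of both calculations with hypothetical numbers:

    # Hypothetical values for illustration only.
    packed_length, seq_len = 16384, 2048       # 8 sequences fit in one packed sample
    valid_micro_num, samples_per_rank = 4, 30  # stand-in for len(ds) // world_size

    micro_bsz = packed_length // seq_len                            # 8
    old_batch = min(valid_micro_num * micro_bsz, samples_per_rank)  # min(32, 30) = 30
    old_batch = old_batch // micro_bsz * micro_bsz                  # round down to 24

    new_batch = min(valid_micro_num, samples_per_rank)              # min(4, 30) = 4
    print(old_batch, new_batch)                                     # 24 4
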
@@ -70,7 +70,6 @@ def evaluate_on_val_dls(
     torch.cuda.empty_cache()
     trainer.eval()
     verbose = gpc.is_rank_for_log()
-    data_cfg = gpc.config.data
 
     for val_name, val_dl in val_dls.items():
         if not streaming and len(val_dl) == 0 and verbose:
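
data_cfg is dropped here because the next hunk deletes what appears to be its only remaining use, the micro_bsz computation at the top of the evaluation loop.
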
@@ -96,12 +95,9 @@ def evaluate_on_val_dls(
         ):
             moe_loss = None
             with torch.inference_mode():
-                micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
                 if gpc.is_using_pp():
-                    total_val_bsz = len(batch[1])
-                    assert total_val_bsz % micro_bsz == 0
-                    num_microbatches = total_val_bsz // micro_bsz
-                    tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
+                    num_microbatches = len(batch[1])
+                    tensor_shape = torch.Size([1, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
 
                     with switch_evaluation_pipeline_scheduler(
                         trainer=trainer,
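
Under pipeline parallelism the scheduler needs the micro-batch count and the per-micro-batch activation shape up front. The removed lines derived both from a packed-sequence micro_bsz; the replacement treats every sample in the batch as its own micro-batch, fixing the leading tensor dimension at 1. A sketch of the two calculations side by side (all values are illustrative stand-ins, not read from a live config):

    import torch

    batch_labels = list(range(16))     # stands in for batch[1]: one entry per sample
    seq_len, hidden_size = 2048, 4096  # stand-ins for gpc.config.SEQ_LEN / HIDDEN_SIZE
    micro_bsz = 4                      # old code: packed_length // SEQ_LEN

    # Old: group the batch into micro-batches of micro_bsz sequences.
    assert len(batch_labels) % micro_bsz == 0
    old_num_microbatches = len(batch_labels) // micro_bsz      # 4
    old_shape = torch.Size([micro_bsz, seq_len, hidden_size])  # [4, 2048, 4096]

    # New: one sample per micro-batch, so the leading dimension is always 1.
    new_num_microbatches = len(batch_labels)                   # 16
    new_shape = torch.Size([1, seq_len, hidden_size])          # [1, 2048, 4096]
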
@@ -119,8 +115,7 @@ def evaluate_on_val_dls(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
                 else:
-                    total_val_bsz = len(batch[1])
-                    grad_accum_size = total_val_bsz // micro_bsz
+                    grad_accum_size = len(batch[1])
                     with switch_evaluation_no_pipeline_scheduler(
                         trainer=trainer,
                         grad_accum_size=grad_accum_size,
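
The non-pipeline branch mirrors the same simplification: grad_accum_size, the number of forward-only steps the scheduler runs per batch, is now simply the sample count, i.e. an implicit micro batch size of 1. A sketch of how a scheduler might consume that value (the chunking helper below is hypothetical, not InternLM API):

    import torch

    def split_for_accumulation(input_ids: torch.Tensor, grad_accum_size: int):
        # Hypothetical helper: yield grad_accum_size micro-batches from one batch.
        yield from input_ids.chunk(grad_accum_size, dim=0)

    batch_input_ids = torch.zeros(16, 2048, dtype=torch.long)  # illustrative packed batch
    grad_accum_size = batch_input_ids.shape[0]                 # new behaviour: len(batch[1])

    for micro in split_for_accumulation(batch_input_ids, grad_accum_size):
        assert micro.shape[0] == 1  # each forward-only step now sees a single sample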