fix comments

pull/517/head
gaoyang07 2023-11-25 23:34:18 +08:00
parent fdbdfcff34
commit 9780c44917
3 changed files with 6 additions and 12 deletions

View File

@ -115,7 +115,8 @@ def args_sanity_check():
data.packed_length == data.seq_len * data.micro_bsz
), "'packed_length' must be equal to 'seq_len * micro_bsz'"
else:
assert data.packed_length is not None, "'packed_length' must be given a value"
assert data.get("packed_length", None) is not None, "'packed_length' must be given a value"
assert data.packed_length % data.seq_len == 0, "'packed_length' must be divisible by 'seq_len'"
if "micro_num" not in data:
data._add_item("micro_num", 1)

View File

@ -301,9 +301,7 @@ def get_validation_data_loader(
else:
# making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
# otherwise too much data may be dropped
micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
batch_size = min(data_cfg.valid_micro_num * micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA))
batch_size = batch_size // micro_bsz * micro_bsz
batch_size = min(data_cfg.valid_micro_num, len(ds) // gpc.get_world_size(ParallelMode.DATA))
if batch_size == 0 and gpc.is_rank_for_log():
logger.info(f"skip validate {val_name}.")

View File

@ -70,7 +70,6 @@ def evaluate_on_val_dls(
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
for val_name, val_dl in val_dls.items():
if not streaming and len(val_dl) == 0 and verbose:
@ -96,12 +95,9 @@ def evaluate_on_val_dls(
):
moe_loss = None
with torch.inference_mode():
micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
assert total_val_bsz % micro_bsz == 0
num_microbatches = total_val_bsz // micro_bsz
tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
num_microbatches = len(batch[1])
tensor_shape = torch.Size([1, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
@ -119,8 +115,7 @@ def evaluate_on_val_dls(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
grad_accum_size = total_val_bsz // micro_bsz
grad_accum_size = len(batch[1])
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,