diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py
index 9110ca2..2ad10c0 100644
--- a/tests/test_data/test_batch_sampler.py
+++ b/tests/test_data/test_batch_sampler.py
@@ -64,20 +64,25 @@ def do_warmup(args):
     packed_length = micro_bsz * sql
     for i in range(init_config.data.total_steps):
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
-        input_shape = batch[0]["type_ids"].shape
-        tokens_num = np.prod(input_shape)
+        type_ids_shape = batch[0]["type_ids"].shape
+        input_ids_shape = batch[0]["input_ids"].shape
+        tokens_num = np.prod(input_ids_shape)
+
+        # If flash attention is not used, 'type_ids' is unpacked when load_new_batch is called.
+        # However, 'input_ids' is unpacked in the pp/nopp engine.
         if not init_config.model.use_flash_attn:
-            if answer[i] > 1:
-                assert input_shape == torch.Size(
-                    [answer[i], micro_bsz, sql]
-                ), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, sql]}"
-            else:
-                assert input_shape == torch.Size([micro_bsz, sql]), f"iter:{i}, {input_shape} != {[micro_bsz, sql]}"
+            assert type_ids_shape == torch.Size(
+                [answer[i], micro_bsz, sql]
+            ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], micro_bsz, sql])}"
         else:
-            assert input_shape == torch.Size(
+            assert type_ids_shape == torch.Size(
                 [answer[i], packed_length]
-            ), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}"
+            ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], packed_length])}"
+
+        assert input_ids_shape == torch.Size(
+            [answer[i], packed_length]
+        ), f"iter:{i}, input_ids_shape: {input_ids_shape} != {torch.Size([answer[i], packed_length])}"

         if gpc.get_global_rank() == 0:
             print(
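
For reference, a minimal sketch (not part of the patch; `micro_bsz`, `sql`, and the stand-in batch size are illustrative values) of the shape convention the updated assertions encode: with flash attention disabled, `type_ids` arrives already unpacked as `[batch, micro_bsz, sql]`, while `input_ids` stays packed as `[batch, micro_bsz * sql]` until the pp/nopp engine unpacks it.

```python
import numpy as np
import torch

micro_bsz, sql = 2, 4            # illustrative values, not the test's config
packed_length = micro_bsz * sql
num_samples = 3                  # stands in for answer[i] at some step i

# 'type_ids' is already unpacked by load_new_batch when flash attention is off.
type_ids = torch.zeros(num_samples, micro_bsz, sql, dtype=torch.long)
# 'input_ids' remains packed; the pp/nopp engine unpacks it later.
input_ids = torch.zeros(num_samples, packed_length, dtype=torch.long)

assert type_ids.shape == torch.Size([num_samples, micro_bsz, sql])
assert input_ids.shape == torch.Size([num_samples, packed_length])
# The token count is derived from the packed 'input_ids' shape.
assert np.prod(input_ids.shape) == num_samples * packed_length
```

This mirrors the patch's move of `tokens_num` from `type_ids` onto `input_ids`, whose packed shape is asserted outside the if/else and is therefore the same in both attention modes.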