[hotfix] fix bug in sequence parallel test (#4887)

pull/4864/head
littsk 2023-10-11 19:30:41 +08:00 committed by GitHub
parent fdec650bb4
commit ffd9a3cbc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -160,7 +160,7 @@ def run_forward_backward_with_hybrid_plugin(
input_shape = data["input_ids"].shape input_shape = data["input_ids"].shape
for k, v in data.items(): for k, v in data.items():
if v.shape == input_shape: if v.shape == input_shape:
data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,))
sharded_model.train() sharded_model.train()
if booster.plugin.stage_manager is not None: if booster.plugin.stage_manager is not None: