mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix bug in sequence parallel test (#4887)
parent
fdec650bb4
commit
ffd9a3cbc9
|
@@ -160,7 +160,7 @@ def run_forward_backward_with_hybrid_plugin(
|
||||||
input_shape = data["input_ids"].shape
|
input_shape = data["input_ids"].shape
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if v.shape == input_shape:
|
if v.shape == input_shape:
|
||||||
data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,))
|
data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,))
|
||||||
|
|
||||||
sharded_model.train()
|
sharded_model.train()
|
||||||
if booster.plugin.stage_manager is not None:
|
if booster.plugin.stage_manager is not None:
|
||||||
|
|
Loading…
Reference in New Issue