fix(pipeline): modify the sequence_parallel in pipeline (#227)

* move sequence_parallel to parallel config

* set the sequence_parallel default value to False

* fix lint

* fix lint

* fix lint

* modify the sequence_parallel in pp
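
For context, sequence_parallel is now read from gpc.config.parallel rather than gpc.config.model. A minimal sketch of what the relevant config section might look like after this change; the surrounding keys and values are hypothetical illustrations, only the sequence_parallel key reflects this PR:

    parallel = dict(
        zero1=8,                  # hypothetical ZeRO-1 setting, for illustration only
        tensor=1,                 # hypothetical tensor-parallel size
        pipeline=dict(size=1),    # hypothetical pipeline-parallel setting
        sequence_parallel=False,  # moved here from the model config; defaults to False
    )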
ytxiong 2023-08-24 14:45:40 +08:00 committed by GitHub
parent 9eec3d9465
commit 9cd1e0314e
1 changed file with 2 additions and 2 deletions


@@ -30,7 +30,7 @@ def get_tensor_shape():
     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
-            if gpc.config.model.sequence_parallel:
+            if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
                     gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
@@ -140,7 +140,7 @@ class PipelineScheduler(BaseScheduler):
             and gpc.get_world_size(ParallelMode.TENSOR) > 1
         )
-        if gpc.config.model.sequence_parallel:
+        if gpc.config.parallel.sequence_parallel:
             self.scatter_gather_tensors = False
         # cache for the batch data
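
To make the shape arithmetic in the first hunk concrete, here is a standalone sketch with hypothetical sizes (independent of gpc and any real config): enabling sequence parallelism divides the flattened (sequence * micro-batch) dimension of the pipeline communication buffer by the tensor-parallel world size.

    # Standalone sketch; SEQ_LEN, micro_bsz, HIDDEN_SIZE and the world size
    # are hypothetical values, not taken from any shipped config.
    SEQ_LEN = 2048
    micro_bsz = 4
    HIDDEN_SIZE = 4096
    sequence_world_size = 8  # tensor-parallel group size

    # sequence_parallel=True: each rank communicates only its sequence slice
    tensor_shape_sp = (SEQ_LEN * micro_bsz // sequence_world_size, HIDDEN_SIZE)
    # sequence_parallel=False: the full flattened dimension is communicated
    tensor_shape = (SEQ_LEN * micro_bsz, HIDDEN_SIZE)

    print(tensor_shape_sp)  # (1024, 4096)
    print(tensor_shape)     # (8192, 4096)

The second hunk forces scatter_gather_tensors off whenever sequence parallelism is enabled, presumably because activations are then already partitioned along the sequence dimension, making the scatter/gather communication optimization redundant.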