mirror of https://github.com/InternLM/InternLM
fix(pipeline): modify the sequence_parallel in pipeline (#227)
* move sequence_parallel to parallel config
* set the sequence_parallel default value to False
* fix lint
* fix lint
* fix lint
* modify the sequence_parallel in pp
parent 9eec3d9465
commit 9cd1e0314e
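For context, a minimal sketch of the config change this commit implies: sequence_parallel is now read from the parallel section of the training config (gpc.config.parallel.sequence_parallel) instead of the model section, and is expected to default to False when unset. The surrounding keys below are illustrative placeholders, not the actual InternLM config.

# Illustrative sketch only: keys other than sequence_parallel are placeholders.

# Before this commit, the flag was looked up as gpc.config.model.sequence_parallel:
model = dict(
    use_flash_attn=True,
    sequence_parallel=True,  # old location
)

# After this commit, it is looked up as gpc.config.parallel.sequence_parallel
# and defaults to False when not set:
parallel = dict(
    sequence_parallel=False,
)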
@@ -30,7 +30,7 @@ def get_tensor_shape():
 
     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
-            if gpc.config.model.sequence_parallel:
+            if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
                     gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
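The branch above divides the flattened (sequence length times micro batch) dimension across tensor-parallel ranks when sequence parallelism is enabled. A minimal sketch of that arithmetic with made-up numbers; the second tuple element (HIDDEN_SIZE) is assumed from the hasattr() check rather than shown in this hunk.

# Hypothetical values, for illustration only.
SEQ_LEN = 2048
micro_bsz = 2
HIDDEN_SIZE = 4096
sequence_world_size = 8  # gpc.get_world_size(ParallelMode.TENSOR)

# Each rank holds 1/sequence_world_size of the tokens, so the first
# dimension of the pipeline communication tensor shrinks accordingly.
tensor_shape = (SEQ_LEN * micro_bsz // sequence_world_size, HIDDEN_SIZE)
print(tensor_shape)  # (512, 4096)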
@@ -140,7 +140,7 @@ class PipelineScheduler(BaseScheduler):
             and gpc.get_world_size(ParallelMode.TENSOR) > 1
         )
 
-        if gpc.config.model.sequence_parallel:
+        if gpc.config.parallel.sequence_parallel:
             self.scatter_gather_tensors = False
 
         # cache for the batch data