mirror of https://github.com/InternLM/InternLM
fix(pipeline): modify the sequence_parallel in pipeline (#227)
* move sequence_parallel to parallel config
* set the sequence_parallel default value to False
* fix lint
* fix lint
* fix lint
* modify the sequence_parallel in pp

pull/228/head
parent 9eec3d9465
commit 9cd1e0314e
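This change moves the sequence_parallel switch from the model section of the training config to the parallel section and defaults it to False. A minimal sketch of what the relevant config section might look like after this commit; the values are illustrative assumptions, and only sequence_parallel is the field the patched code reads:

parallel = dict(
    zero1=8,                  # illustrative placeholder, not part of this commit
    pipeline=dict(size=1),    # illustrative placeholder, not part of this commit
    sequence_parallel=False,  # new home of the flag; defaults to False per the commit message
)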
@@ -30,7 +30,7 @@ def get_tensor_shape():
     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
-            if gpc.config.model.sequence_parallel:
+            if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
                     gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
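As a quick sanity check on the first dimension of tensor_shape computed above, with illustrative numbers (SEQ_LEN, micro_bsz and the tensor-parallel world size below are assumptions, not values taken from this commit):

# Illustrative values only: with sequence parallelism enabled, the flattened
# sequence is split evenly across the tensor-parallel group.
SEQ_LEN = 2048           # assumed sequence length
micro_bsz = 2            # assumed micro batch size
sequence_world_size = 8  # assumed tensor-parallel world size
print(SEQ_LEN * micro_bsz // sequence_world_size)  # -> 512 tokens per rank in the first dim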
@@ -140,7 +140,7 @@ class PipelineScheduler(BaseScheduler):
             and gpc.get_world_size(ParallelMode.TENSOR) > 1
         )
 
-        if gpc.config.model.sequence_parallel:
+        if gpc.config.parallel.sequence_parallel:
             self.scatter_gather_tensors = False
 
         # cache for the batch data
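Because the flag has moved, code outside this file that still reads gpc.config.model.sequence_parallel needs the same update. A hypothetical compatibility helper, not part of InternLM, that prefers the new location and falls back to the old one, defaulting to False as this commit does:

# Hypothetical helper (not part of InternLM): look for sequence_parallel in the
# new parallel section first, then the old model section, defaulting to False.
def read_sequence_parallel(config) -> bool:
    for section_name in ("parallel", "model"):
        section = getattr(config, section_name, None)
        if section is not None and getattr(section, "sequence_parallel", None) is not None:
            return bool(section.sequence_parallel)
    return False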