fix(*): move sequence_parallel to parallel config (#224)

* move sequence_parallel to parallel config

* set the sequence_parallel default value to False

* fix lint

* fix lint

* fix lint
pull/225/head
ytxiong 2023-08-24 09:49:04 +08:00 committed by GitHub
parent 32664328e7
commit a017cab4b3
7 changed files with 23 additions and 17 deletions

View File

@@ -125,7 +125,6 @@ model = dict(
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
- sequence_parallel=False,
)
"""
zero1 parallel:
@@ -142,6 +141,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
parallel = dict(
zero1=8,
pipeline=dict(size=1, interleaved_overlap=True),
+ sequence_parallel=False,
)
cudnn_deterministic = False
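
Taken together, the two hunks above move the flag out of the model section: after this commit the sequence-parallel switch lives in the parallel dict and defaults to False. A minimal sketch of the resulting config sections, using only the keys shown in the diff:

model = dict(
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # no sequence_parallel entry here anymore
)

parallel = dict(
    zero1=8,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,  # moved here; off by default
)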

View File

@@ -265,11 +265,13 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
# process the model config
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
- if "sequence_parallel" not in gpc.config.model:
- gpc.config.model._add_item("sequence_parallel", False)
+ # process the parallel config
+ if "sequence_parallel" not in gpc.config.parallel:
+ gpc.config.parallel._add_item("sequence_parallel", False)
else:
assert not (
- gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+ gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
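
For reference, a standalone sketch of the constraint enforced above, with plain dicts standing in for gpc.config.model and gpc.config.parallel (the stand-in names and values are illustrative only):

# Illustrative stand-ins for gpc.config.model and gpc.config.parallel.
model = {"use_flash_attn": True}
parallel = {"sequence_parallel": True}  # allowed only because use_flash_attn is True

# Same check as in the diff: sequence parallel requires flash attention.
assert not (
    parallel["sequence_parallel"] is True and model["use_flash_attn"] is False
), "sequence parallel does not support use_flash_attn=False"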

View File

@@ -56,7 +56,7 @@ class Embedding1D(nn.Module):
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
- if gpc.config.model.sequence_parallel:
+ if gpc.config.parallel.sequence_parallel:
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
return output

View File

@@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear):
weight,
self.bias,
process_group=self.process_group,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
)
@@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
weight,
self.bias,
process_group=self.process_group,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
)
@@ -173,7 +173,7 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
@@ -182,7 +182,7 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
@@ -191,7 +191,7 @@ class FeedForward(nn.Module):
out_features,
process_group,
bias=bias,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

View File

@@ -121,7 +121,7 @@ class PackedFlashBaseLayer1D(nn.Module):
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
@@ -294,7 +294,7 @@ class PackedFlashInternLm1D(nn.Module):
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

View File

@@ -82,7 +82,7 @@ class MHA(nn.Module):
3 * embed_dim,
process_group,
bias=True,
- sequence_parallel=gpc.config.model.sequence_parallel,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
@@ -95,7 +95,11 @@ class MHA(nn.Module):
# output projection always have the bias (for now)
self.out_proj = RowParallelLinearTorch(
- embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
+ embed_dim,
+ embed_dim,
+ process_group,
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
+ **factory_kwargs,
)
# need to assign tp attribute so that internlm know it is tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:

View File

@@ -52,12 +52,12 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
@contextmanager
def switch_sequence_parallel_mode():
- prev_mode = gpc.config.model.sequence_parallel
+ prev_mode = gpc.config.parallel.sequence_parallel
try:
- gpc.config.model.sequence_parallel = False
+ gpc.config.parallel.sequence_parallel = False
yield
finally:
- gpc.config.model.sequence_parallel = prev_mode
+ gpc.config.parallel.sequence_parallel = prev_mode
def evaluate_on_val_dls(
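
For context, a self-contained sketch of the save/override/restore pattern in the context manager above, with a SimpleNamespace standing in for gpc.config.parallel (only switch_sequence_parallel_mode comes from the diff; everything else is illustrative):

from contextlib import contextmanager
from types import SimpleNamespace

# Illustrative stand-in for gpc.config.parallel.
parallel_cfg = SimpleNamespace(sequence_parallel=True)

@contextmanager
def switch_sequence_parallel_mode():
    # Save the flag, force it off for evaluation, and restore it on exit.
    prev_mode = parallel_cfg.sequence_parallel
    try:
        parallel_cfg.sequence_parallel = False
        yield
    finally:
        parallel_cfg.sequence_parallel = prev_mode

with switch_sequence_parallel_mode():
    assert parallel_cfg.sequence_parallel is False  # disabled inside the block
assert parallel_cfg.sequence_parallel is True  # previous value restored afterwards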