mirror of https://github.com/InternLM/InternLM
fix(*): move sequence_parallel to parallel config (#224)
* move sequence_parallel to parallel config
* set the sequence_parallel default value to False
* fix lint
parent 32664328e7
commit a017cab4b3
@@ -125,7 +125,6 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
-    sequence_parallel=False,
 )
 """
 zero1 parallel:
@@ -142,6 +141,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
 parallel = dict(
     zero1=8,
     pipeline=dict(size=1, interleaved_overlap=True),
+    sequence_parallel=False,
 )

 cudnn_deterministic = False
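With this change the flag lives under `parallel` rather than `model`. A minimal sketch of a training config with sequence parallelism switched on, using the defaults shown above for the surrounding values (illustrative, not prescriptive):

    # Illustrative config sketch: sequence_parallel now belongs to `parallel`, not `model`.
    model = dict(
        layer_norm_epsilon=1e-5,
        use_flash_attn=True,  # must stay True when sequence parallelism is enabled
        num_chunks=1,
    )

    parallel = dict(
        zero1=8,
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,  # enable it here; the default is False
    )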
@@ -265,11 +265,13 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     # process the model config
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
-    if "sequence_parallel" not in gpc.config.model:
-        gpc.config.model._add_item("sequence_parallel", False)
+
+    # process the parallel config
+    if "sequence_parallel" not in gpc.config.parallel:
+        gpc.config.parallel._add_item("sequence_parallel", False)
     else:
         assert not (
-            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+            gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
         ), "sequence parallel does not support use_flash_attn=False"

     # feishu webhook address for alerting
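The assert above couples the two settings: sequence parallelism is only supported together with flash attention. A self-contained sketch of the same check, with plain dicts standing in for gpc.config (the function name is hypothetical):

    def check_sequence_parallel(model_cfg: dict, parallel_cfg: dict) -> None:
        """Reject the unsupported combination sequence_parallel=True with use_flash_attn=False."""
        assert not (
            parallel_cfg.get("sequence_parallel", False) is True
            and model_cfg.get("use_flash_attn", True) is False
        ), "sequence parallel does not support use_flash_attn=False"

    check_sequence_parallel({"use_flash_attn": True}, {"sequence_parallel": True})  # passes
    # check_sequence_parallel({"use_flash_attn": False}, {"sequence_parallel": True})  # would raise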
@@ -56,7 +56,7 @@ class Embedding1D(nn.Module):

         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

-        if gpc.config.model.sequence_parallel:
+        if gpc.config.parallel.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

         return output
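When sequence parallelism is on, the embedding output is split along the sequence dimension (dim=1) so that each tensor-parallel rank keeps only its own slice. A toy, single-process illustration of that split, where torch.chunk stands in for split_forward_gather_backward (which, as its name suggests, also gathers gradients along the same dimension in the backward pass); the sizes and world size are made up:

    import torch

    batch, seqlen, hidden = 2, 8, 4
    world_size = 2  # hypothetical tensor-parallel world size

    output = torch.randn(batch, seqlen, hidden)

    # Forward behaviour of the split: each rank would keep chunks[rank].
    chunks = torch.chunk(output, world_size, dim=1)
    print([tuple(c.shape) for c in chunks])  # [(2, 4, 4), (2, 4, 4)]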
@@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear):
             weight,
             self.bias,
             process_group=self.process_group,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
         )


@@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
             weight,
             self.bias,
             process_group=self.process_group,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
         )


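Both linear layers now read the flag from gpc.config.parallel. Code outside this commit that still expects it under the model config can bridge the move with a small helper; a sketch under the assumption that the config supports attribute access (SimpleNamespace stands in for gpc.config, and the helper name is hypothetical):

    from types import SimpleNamespace

    # Stand-in for gpc.config after this commit (illustrative values).
    config = SimpleNamespace(
        model=SimpleNamespace(use_flash_attn=True),
        parallel=SimpleNamespace(zero1=8, sequence_parallel=False),
    )

    def is_sequence_parallel(cfg) -> bool:
        """Prefer the new location (parallel), fall back to the old one (model) for older configs."""
        if hasattr(cfg.parallel, "sequence_parallel"):
            return cfg.parallel.sequence_parallel
        return getattr(cfg.model, "sequence_parallel", False)

    print(is_sequence_parallel(config))  # False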
@@ -173,7 +173,7 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -182,7 +182,7 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -191,7 +191,7 @@ class FeedForward(nn.Module):
             out_features,
             process_group,
             bias=bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
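All three projection layers in FeedForward must see the same flag value, now taken from the parallel config. One way to keep them consistent is to read the flag once and reuse the resulting keyword arguments; a small illustrative sketch (the helper name is hypothetical, SimpleNamespace stands in for gpc.config.parallel):

    from types import SimpleNamespace

    parallel_cfg = SimpleNamespace(sequence_parallel=False)  # stand-in for gpc.config.parallel

    def shared_linear_kwargs(parallel_cfg, device=None, dtype=None) -> dict:
        """Keyword arguments shared by the three projections, reading the flag exactly once."""
        return dict(
            sequence_parallel=parallel_cfg.sequence_parallel,
            device=device,
            dtype=dtype,
        )

    print(shared_linear_kwargs(parallel_cfg))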
@@ -121,7 +121,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias1=False,
                bias2=False,
-               sequence_parallel=gpc.config.model.sequence_parallel,
+               sequence_parallel=gpc.config.parallel.sequence_parallel,
                checkpoint_lvl=0,
                heuristic="auto",
                device=device,
@@ -294,7 +294,7 @@ class PackedFlashInternLm1D(nn.Module):
                max_position_embeddings=-1,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                padding_idx=None,
-               sequence_parallel=gpc.config.model.sequence_parallel,
+               sequence_parallel=gpc.config.parallel.sequence_parallel,
                device=device,
                dtype=dtype,
            )
@@ -82,7 +82,7 @@ class MHA(nn.Module):
             3 * embed_dim,
             process_group,
             bias=True,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         )  # according to https://spaces.ac.cn/archives/9577

@@ -95,7 +95,11 @@ class MHA(nn.Module):

         # output projection always have the bias (for now)
         self.out_proj = RowParallelLinearTorch(
-            embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
+            embed_dim,
+            embed_dim,
+            process_group,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
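For orientation, the first MHA hunk builds a fused query/key/value projection of width 3 * embed_dim, and out_proj maps the attention output back to embed_dim; both now take sequence_parallel from the parallel config. A non-parallel sketch with plain nn.Linear showing the shapes involved (sizes and variable names are illustrative; the real layers are ColumnParallelLinearTorch / RowParallelLinearTorch):

    import torch
    from torch import nn

    batch, seqlen, embed_dim = 2, 8, 16

    wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=True)  # fused q/k/v projection
    out_proj = nn.Linear(embed_dim, embed_dim)             # output projection (always has a bias)

    x = torch.randn(batch, seqlen, embed_dim)
    q, k, v = wqkv(x).chunk(3, dim=-1)  # each is (batch, seqlen, embed_dim)
    y = out_proj(v)                     # placeholder for the attention output
    print(q.shape, k.shape, v.shape, y.shape)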
@@ -52,12 +52,12 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape

 @contextmanager
 def switch_sequence_parallel_mode():
-    prev_mode = gpc.config.model.sequence_parallel
+    prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        gpc.config.model.sequence_parallel = False
+        gpc.config.parallel.sequence_parallel = False
         yield
     finally:
-        gpc.config.model.sequence_parallel = prev_mode
+        gpc.config.parallel.sequence_parallel = prev_mode


 def evaluate_on_val_dls(
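The context manager now toggles the flag on the parallel config instead of the model config. A self-contained sketch of the same pattern and how evaluation code would use it (SimpleNamespace stands in for gpc.config.parallel):

    from contextlib import contextmanager
    from types import SimpleNamespace

    parallel_cfg = SimpleNamespace(sequence_parallel=True)  # stand-in for gpc.config.parallel

    @contextmanager
    def switch_sequence_parallel_mode(cfg):
        """Temporarily disable sequence parallelism (e.g. during evaluation), then restore it."""
        prev_mode = cfg.sequence_parallel
        try:
            cfg.sequence_parallel = False
            yield
        finally:
            cfg.sequence_parallel = prev_mode

    with switch_sequence_parallel_mode(parallel_cfg):
        assert parallel_cfg.sequence_parallel is False  # evaluation runs with the flag off
    assert parallel_cfg.sequence_parallel is True       # restored afterwards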