mirror of https://github.com/InternLM/InternLM
fix(*): move sequence_parallel to parallel config (#224)
* move sequence_parallel to parallel config
* set the sequence_parallel default value to False
* fix lint
parent 32664328e7
commit a017cab4b3
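Summary of the user-facing change: sequence_parallel is now read from the parallel config instead of the model config, and it defaults to False. A minimal sketch of the migrated config fragment, with field values taken from the sample config in the diff below (the rest of the config file is omitted):

model = dict(
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    # sequence_parallel no longer lives here
)

parallel = dict(
    zero1=8,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,  # moved here; keep False unless use_flash_attn=True
)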
@@ -125,7 +125,6 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
-    sequence_parallel=False,
 )
 """
 zero1 parallel:
@@ -142,6 +141,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
 parallel = dict(
     zero1=8,
     pipeline=dict(size=1, interleaved_overlap=True),
+    sequence_parallel=False,
 )
 
 cudnn_deterministic = False
@@ -265,11 +265,13 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
     # process the model config
     if "use_flash_attn" not in gpc.config.model:
         gpc.config.model._add_item("use_flash_attn", True)
-    if "sequence_parallel" not in gpc.config.model:
-        gpc.config.model._add_item("sequence_parallel", False)
+
+    # process the parallel config
+    if "sequence_parallel" not in gpc.config.parallel:
+        gpc.config.parallel._add_item("sequence_parallel", False)
     else:
         assert not (
-            gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
+            gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
         ), "sequence parallel does not support use_flash_attn=False"
 
     # feishu webhook address for alerting
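The relocated check still ties sequence parallelism to FlashAttention. As a hedged illustration (the field values are hypothetical and the rest of the config is omitted), a config like the following would now fail the assertion above at launch:

model = dict(
    use_flash_attn=False,  # FlashAttention disabled
)

parallel = dict(
    zero1=8,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,  # read from gpc.config.parallel as of this commit
)

# Launch-time config processing raises:
# AssertionError: sequence parallel does not support use_flash_attn=False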
@@ -56,7 +56,7 @@ class Embedding1D(nn.Module):
 
         output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
 
-        if gpc.config.model.sequence_parallel:
+        if gpc.config.parallel.sequence_parallel:
             output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
 
         return output
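For orientation, split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1) keeps only the local rank's shard of the sequence dimension in the forward pass (and gathers gradients in the backward pass). A self-contained sketch of just the forward-path splitting idea in plain PyTorch; this is not InternLM's implementation, and the shape and rank handling are illustrative:

import torch

def split_along_seq_dim(output: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Keep only this rank's chunk of dim=1 (the sequence dimension).
    chunks = torch.chunk(output, world_size, dim=1)
    return chunks[rank].contiguous()

# Example: a (batch=4, seq_len=8, hidden=16) embedding output split 2 ways.
emb = torch.randn(4, 8, 16)
shard = split_along_seq_dim(emb, world_size=2, rank=0)
print(shard.shape)  # torch.Size([4, 4, 16])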
@@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear):
             weight,
             self.bias,
             process_group=self.process_group,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
         )
 
 
@@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
             weight,
             self.bias,
             process_group=self.process_group,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
         )
 
 
@@ -173,7 +173,7 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -182,7 +182,7 @@ class FeedForward(nn.Module):
             hidden_features,
             process_group,
             bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -191,7 +191,7 @@ class FeedForward(nn.Module):
             out_features,
             process_group,
             bias=bias,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -121,7 +121,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             process_group=gpc.get_group(ParallelMode.TENSOR),
             bias1=False,
             bias2=False,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             checkpoint_lvl=0,
             heuristic="auto",
             device=device,
@@ -294,7 +294,7 @@ class PackedFlashInternLm1D(nn.Module):
             max_position_embeddings=-1,
             process_group=gpc.get_group(ParallelMode.TENSOR),
             padding_idx=None,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             device=device,
             dtype=dtype,
         )
@@ -82,7 +82,7 @@ class MHA(nn.Module):
             3 * embed_dim,
             process_group,
             bias=True,
-            sequence_parallel=gpc.config.model.sequence_parallel,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
             **factory_kwargs,
         ) # according to https://spaces.ac.cn/archives/9577
 
@@ -95,7 +95,11 @@ class MHA(nn.Module):
 
         # output projection always have the bias (for now)
         self.out_proj = RowParallelLinearTorch(
-            embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
+            embed_dim,
+            embed_dim,
+            process_group,
+            sequence_parallel=gpc.config.parallel.sequence_parallel,
+            **factory_kwargs,
         )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
@@ -52,12 +52,12 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
 
 @contextmanager
 def switch_sequence_parallel_mode():
-    prev_mode = gpc.config.model.sequence_parallel
+    prev_mode = gpc.config.parallel.sequence_parallel
     try:
-        gpc.config.model.sequence_parallel = False
+        gpc.config.parallel.sequence_parallel = False
         yield
     finally:
-        gpc.config.model.sequence_parallel = prev_mode
+        gpc.config.parallel.sequence_parallel = prev_mode
 
 
 def evaluate_on_val_dls(
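switch_sequence_parallel_mode keeps its contract: sequence parallelism is disabled while evaluation runs and the previous value is restored afterwards, only now the flag lives on gpc.config.parallel rather than gpc.config.model. A self-contained sketch of the same save/disable/restore pattern against a stand-in config object (not InternLM's gpc):

from contextlib import contextmanager

class _StubParallelConfig:
    sequence_parallel = True  # stand-in for gpc.config.parallel

cfg = _StubParallelConfig()

@contextmanager
def switch_sequence_parallel_mode():
    # Same pattern as the diff above: save, disable, restore in finally.
    prev_mode = cfg.sequence_parallel
    try:
        cfg.sequence_parallel = False
        yield
    finally:
        cfg.sequence_parallel = prev_mode

with switch_sequence_parallel_mode():
    assert cfg.sequence_parallel is False  # disabled for the duration of evaluation
assert cfg.sequence_parallel is True  # restored afterwards, even if the block raised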