mirror of https://github.com/InternLM/InternLM

feat(model/linear.py): set block 0 full weight

parent 82204eea59
commit 0d1fa037dd
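This change gates the new behaviour on a `block_0_full_weight` flag read from `gpc.config.parallel`, alongside the existing tensor-parallel mode and `sequence_parallel` settings referenced in the hunks below. A minimal sketch of how the flag might be declared in a training config; only the keys actually read in this diff come from the change itself, everything else is an illustrative assumption:

# Hypothetical config excerpt -- only parallel["tensor"]["mode"],
# parallel.sequence_parallel and parallel.block_0_full_weight appear in this diff.
parallel = dict(
    tensor=dict(size=8, mode="fstp"),  # "fstp" selects the FSTPLinear path
    sequence_parallel=True,
    block_0_full_weight=True,          # new: keep block 0's linears unsharded
)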
@@ -175,6 +175,7 @@ class FeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
+        block_idx: int = 0,
     ):
         super().__init__()

@@ -248,38 +249,62 @@ class FSTPFeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
+        block_idx: int = 0,
     ):
         super().__init__()

         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w2 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w3 = FSTPLinear(
-            hidden_features,
-            out_features,
-            process_group,
-            bias=bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
+        if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
+            self.w1 = nn.Linear(
+                in_features,
+                hidden_features,
+                bias,
+                device=device,
+                dtype=dtype,
+            )
+            self.w2 = nn.Linear(
+                in_features,
+                hidden_features,
+                bias,
+                device=device,
+                dtype=dtype,
+            )
+            self.w3 = nn.Linear(
+                hidden_features,
+                out_features,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        else:
+            self.w1 = FSTPLinear(
+                in_features,
+                hidden_features,
+                process_group,
+                bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            self.w2 = FSTPLinear(
+                in_features,
+                hidden_features,
+                process_group,
+                bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            self.w3 = FSTPLinear(
+                hidden_features,
+                out_features,
+                process_group,
+                bias=bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )

     def forward(self, x):
         w1_o = self.w1(x)
@@ -449,10 +474,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                     # print(f"name: {name}", flush=True)
                     if name == "out_proj":
                         self.FSTP_outs.append(child)
-                        # self.module_to_index[child] = idx
+                        self.module_to_index[child] = idx
                     if name == "Wqkv":
                         self.FSTP_wqkvs.append(child)
-                        # self.module_to_index[child] = idx
+                        self.module_to_index[child] = idx
                     if isinstance(child, FSTPLinear):
                         self.module_to_index[child] = idx
                         self.block_module[idx][index] = child
@@ -489,6 +514,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight(block_index + 1)
+                # print(f"_all_gather_block_weight for block {block_index+1}", flush=True)

         def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
             block_index = self.block_to_index[block]
@@ -513,7 +539,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]

-
         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
             block_index = self.module_to_index[module]
             if block_index != 0:
@@ -562,35 +587,37 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
-            if name_index == 4:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
+            if block_index != 0:
+                if name_index == 4:
+                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handler.wait()
+                    self.FSTP_global_weights[module] = total_weight

-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.FSTP_global_handle[next_module] = weights_handler
-            else:
-                handler = self.FSTP_global_handle[module]
-                handler.wait()
-                if name_index != 0:
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
+                    # start the all-gather for next module
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+                else:
+                    handler = self.FSTP_global_handle[module]
+                    handler.wait()
+                    if name_index != 0:
+                        next_module = self.block_module[block_index][name_index - 1]
+                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                            next_module.weight, self.process_group, async_op=True
+                        )
+                        self.FSTP_global_handle[next_module] = weights_handler

         def _post_backward_hook_for_module(module, grad_input, grad_output):
-            del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]

         # for block in self.FSTP_blocks:
         # block.register_forward_pre_hook(_pre_forward_hook_for_block)
         # block.register_forward_hook(_post_forward_hook_for_block)
         # block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
         # block.register_full_backward_hook(_post_backward_hook_for_block)

         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
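The pre-backward hook above prefetches the next module's full weight with an asynchronous all-gather and parks the communication handle in `FSTP_global_handle` so the wait can overlap with backward computation. A self-contained sketch of that launch-then-wait pattern using plain `torch.distributed` (single-rank gloo group so it runs standalone; `all_gather_raw` in the diff is InternLM's own helper, so the collective below is an assumed stand-in, not the project's API):

import os
import torch
import torch.distributed as dist

# Single-rank gloo group so the sketch runs on its own; a real FSTP setup would
# pass the tensor-parallel process group instead.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

shard = torch.randn(4, 8)  # this rank's shard of a weight
gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]

# Launch asynchronously and keep the handle (the diff stores it per module in
# self.FSTP_global_handle); other work can run before the wait.
handle = dist.all_gather(gathered, shard, async_op=True)
# ... overlapped computation would go here ...
handle.wait()
full_weight = torch.cat(gathered, dim=0)  # analogous to FSTP_global_weights[module]

dist.destroy_process_group()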
@@ -78,6 +78,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         tp_mode: str = "origin_tp",
+        block_idx: int = 0,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -103,6 +104,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             device=device,
             dtype=dtype,
             tp_mode=tp_mode,
+            block_idx=block_idx,
         )

         self.dropout1 = nn.Dropout(drop_rate)
@@ -123,6 +125,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 bias=False,
                 device=device,
                 dtype=dtype,
+                block_idx=block_idx,
             )
         else:
             self.mlp = ParallelFusedMLP(
@@ -344,6 +347,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     tp_mode=self.tp_mode,
+                    block_idx=lid,
                 )
                 for lid in range(num_layers)
             ]
@@ -410,7 +414,7 @@ class PackedFlashInternLm1D(nn.Module):
             # Evaluation
             if hidden_states.ndim == 3:
                 hidden_states = self.head(hidden_states, gather_dim=1)
             else: # Training
                 hidden_states = self.head(hidden_states, gather_dim=0)

         if not self.parallel_output:
@@ -51,7 +51,6 @@ class _SeqAllToAll(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
-
         ctx.group = group
         ctx.scatter_idx = scatter_idx
         ctx.gather_idx = gather_idx
@@ -91,7 +90,6 @@ class DistributedAttention(torch.nn.Module):
         second_scatter_idx: int = 0,
         second_gather_idx: int = 1,
     ) -> None:
-
         super().__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
@@ -178,6 +176,7 @@ class MHA(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         tp_mode: str = "origin_tp",
+        block_idx: int = 0,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -206,14 +205,23 @@ class MHA(nn.Module):

         # notice here should change bias=True
         Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        self.Wqkv = Wqkv_cls(
-            embed_dim,
-            3 * embed_dim,
-            process_group,
-            bias=False,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            **factory_kwargs,
-        ) # according to https://spaces.ac.cn/archives/9577
+        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
+            Wqkv_cls = nn.Linear
+            self.Wqkv = Wqkv_cls(
+                embed_dim,
+                3 * embed_dim,
+                bias=False,
+                **factory_kwargs,
+            )
+        else:
+            self.Wqkv = Wqkv_cls(
+                embed_dim,
+                3 * embed_dim,
+                process_group,
+                bias=False,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                **factory_kwargs,
+            ) # according to https://spaces.ac.cn/archives/9577

         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -227,14 +235,23 @@ class MHA(nn.Module):

         # output projection always have the bias (for now)
         out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        self.out_proj = out_proj_cls(
-            embed_dim,
-            embed_dim,
-            process_group,
-            bias=False,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            **factory_kwargs,
-        )
+        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
+            out_proj_cls = nn.Linear
+            self.out_proj = out_proj_cls(
+                embed_dim,
+                embed_dim,
+                bias=False,
+                **factory_kwargs,
+            )
+        else:
+            self.out_proj = out_proj_cls(
+                embed_dim,
+                embed_dim,
+                process_group,
+                bias=False,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                **factory_kwargs,
+            )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
@@ -110,8 +110,8 @@ def initialize_model():
     model = wrap_FSDP_model(model)

     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        # handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
     return model
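`initialize_model` now builds the coarse-grained handler and still calls `_register_sync_parameters_hook()`, which hangs the hooks shown earlier onto the FSTP modules. A minimal, self-contained sketch of the underlying PyTorch hook mechanism (the module and hook body are illustrative, not InternLM's):

import torch
import torch.nn as nn

linear = nn.Linear(8, 8)

def pre_forward_hook(module: nn.Module, inputs):
    # In the real handler this is where a prefetched full weight would be waited on.
    print(f"pre-forward on {module.__class__.__name__}")

linear.register_forward_pre_hook(pre_forward_hook)
linear(torch.randn(2, 8))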
@@ -396,6 +396,9 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     )


+tgs_list = []
+
+
 @llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
@@ -568,3 +571,9 @@ def record_current_batch_training_metrics(
             step_count=batch_count,
             cur_step_loss=loss.item(),
         )
+
+        if batch_count >= 5:
+            tgs_list.append(tgs_origin)
+        if batch_count == gpc.config.data.total_steps - 1:
+            print(tgs_list, flush=True)
+            print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
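The added bookkeeping treats the first five steps as warm-up and prints the mean tokens-per-GPU-per-second at the final step. The same logic in isolation, with made-up throughput numbers standing in for `tgs_origin` and `gpc.config.data.total_steps`:

# Stand-alone version of the new throughput averaging; all values are made up.
tgs_list = []
total_steps = 10  # stands in for gpc.config.data.total_steps

for batch_count, tgs_origin in enumerate([900, 1500, 1780, 1810, 1820, 1825, 1830, 1828, 1826, 1829]):
    if batch_count >= 5:  # skip warm-up steps
        tgs_list.append(tgs_origin)
    if batch_count == total_steps - 1:
        print(tgs_list, flush=True)
        print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)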