mirror of https://github.com/InternLM/InternLM
feat(model/linear.py): set block 0 full weight
parent
82204eea59
commit
0d1fa037dd
|
@ -175,6 +175,7 @@ class FeedForward(nn.Module):
|
|||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
multiple_of: int = 256,
|
||||
block_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -248,38 +249,62 @@ class FSTPFeedForward(nn.Module):
|
|||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
multiple_of: int = 256,
|
||||
block_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = FSTPLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w2 = FSTPLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w3 = FSTPLinear(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
|
||||
self.w1 = nn.Linear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w3 = nn.Linear(
|
||||
hidden_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
self.w1 = FSTPLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w2 = FSTPLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.w3 = FSTPLinear(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
w1_o = self.w1(x)
|
||||
|
@ -449,10 +474,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
# print(f"name: {name}", flush=True)
|
||||
if name == "out_proj":
|
||||
self.FSTP_outs.append(child)
|
||||
# self.module_to_index[child] = idx
|
||||
self.module_to_index[child] = idx
|
||||
if name == "Wqkv":
|
||||
self.FSTP_wqkvs.append(child)
|
||||
# self.module_to_index[child] = idx
|
||||
self.module_to_index[child] = idx
|
||||
if isinstance(child, FSTPLinear):
|
||||
self.module_to_index[child] = idx
|
||||
self.block_module[idx][index] = child
|
||||
|
@ -489,6 +514,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
# start the all-gather for next block
|
||||
if block_index + 1 < gpc.config.NUM_LAYER:
|
||||
self._all_gather_block_weight(block_index + 1)
|
||||
# print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
|
||||
|
||||
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
|
||||
block_index = self.block_to_index[block]
|
||||
|
@ -512,14 +538,13 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
del self.block_handles[block]
|
||||
for module in fsdp_modules:
|
||||
del self.FSTP_global_weights[module]
|
||||
|
||||
|
||||
|
||||
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
|
||||
block_index = self.module_to_index[module]
|
||||
if block_index != 0:
|
||||
handler = self.FSTP_global_handle[module]
|
||||
handler.wait()
|
||||
|
||||
|
||||
def _post_forward_hook_for_module(module: nn.Module, input, output):
|
||||
if module in self.FSTP_global_weights:
|
||||
del self.FSTP_global_weights[module]
|
||||
|
@ -558,46 +583,48 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
del self.block_handles[block]
|
||||
for module in fsdp_modules:
|
||||
del self.FSTP_global_weights[module]
|
||||
|
||||
|
||||
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
|
||||
block_index = self.module_to_index[module]
|
||||
name_index = self.module_name_index[module]
|
||||
if name_index == 4:
|
||||
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||
weight_handler.wait()
|
||||
self.FSTP_global_weights[module] = total_weight
|
||||
if block_index != 0:
|
||||
if name_index == 4:
|
||||
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||
weight_handler.wait()
|
||||
self.FSTP_global_weights[module] = total_weight
|
||||
|
||||
# start the all-gather for next module
|
||||
next_module = self.block_module[block_index][name_index - 1]
|
||||
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||
next_module.weight, self.process_group, async_op=True
|
||||
)
|
||||
self.FSTP_global_handle[next_module] = weights_handler
|
||||
else:
|
||||
handler = self.FSTP_global_handle[module]
|
||||
handler.wait()
|
||||
if name_index != 0:
|
||||
# start the all-gather for next module
|
||||
next_module = self.block_module[block_index][name_index - 1]
|
||||
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||
next_module.weight, self.process_group, async_op=True
|
||||
)
|
||||
self.FSTP_global_handle[next_module] = weights_handler
|
||||
else:
|
||||
handler = self.FSTP_global_handle[module]
|
||||
handler.wait()
|
||||
if name_index != 0:
|
||||
next_module = self.block_module[block_index][name_index - 1]
|
||||
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||
next_module.weight, self.process_group, async_op=True
|
||||
)
|
||||
self.FSTP_global_handle[next_module] = weights_handler
|
||||
|
||||
def _post_backward_hook_for_module(module, grad_input, grad_output):
|
||||
del self.FSTP_global_weights[module]
|
||||
if module in self.FSTP_global_weights:
|
||||
del self.FSTP_global_weights[module]
|
||||
|
||||
# for block in self.FSTP_blocks:
|
||||
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
||||
# block.register_forward_hook(_post_forward_hook_for_block)
|
||||
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
|
||||
# block.register_full_backward_hook(_post_backward_hook_for_block)
|
||||
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
||||
# block.register_forward_hook(_post_forward_hook_for_block)
|
||||
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
|
||||
# block.register_full_backward_hook(_post_backward_hook_for_block)
|
||||
|
||||
for out_proj in self.FSTP_outs:
|
||||
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
|
||||
|
||||
# for wqkv in self.FSTP_wqkvs:
|
||||
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
||||
|
||||
|
||||
for module in self.FSTP_modules:
|
||||
module.register_forward_pre_hook(_pre_forward_hook_for_module)
|
||||
module.register_forward_hook(_post_forward_hook_for_module)
|
||||
|
|
|
@ -78,6 +78,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
tp_mode: str = "origin_tp",
|
||||
block_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
|
@ -103,6 +104,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
tp_mode=tp_mode,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(drop_rate)
|
||||
|
@ -123,6 +125,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
else:
|
||||
self.mlp = ParallelFusedMLP(
|
||||
|
@ -344,6 +347,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
tp_mode=self.tp_mode,
|
||||
block_idx=lid,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
|
@ -410,7 +414,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
# Evaluation
|
||||
if hidden_states.ndim == 3:
|
||||
hidden_states = self.head(hidden_states, gather_dim=1)
|
||||
else: # Training
|
||||
else: # Training
|
||||
hidden_states = self.head(hidden_states, gather_dim=0)
|
||||
|
||||
if not self.parallel_output:
|
||||
|
|
|
@ -51,7 +51,6 @@ class _SeqAllToAll(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
|
||||
|
||||
ctx.group = group
|
||||
ctx.scatter_idx = scatter_idx
|
||||
ctx.gather_idx = gather_idx
|
||||
|
@ -91,7 +90,6 @@ class DistributedAttention(torch.nn.Module):
|
|||
second_scatter_idx: int = 0,
|
||||
second_gather_idx: int = 1,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
self.local_attn = local_attention
|
||||
self.spg = sequence_process_group
|
||||
|
@ -178,6 +176,7 @@ class MHA(nn.Module):
|
|||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
tp_mode: str = "origin_tp",
|
||||
block_idx: int = 0,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
@ -206,14 +205,23 @@ class MHA(nn.Module):
|
|||
|
||||
# notice here should change bias=True
|
||||
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
|
||||
self.Wqkv = Wqkv_cls(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
process_group,
|
||||
bias=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
) # according to https://spaces.ac.cn/archives/9577
|
||||
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
|
||||
Wqkv_cls = nn.Linear
|
||||
self.Wqkv = Wqkv_cls(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=False,
|
||||
**factory_kwargs,
|
||||
)
|
||||
else:
|
||||
self.Wqkv = Wqkv_cls(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
process_group,
|
||||
bias=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
) # according to https://spaces.ac.cn/archives/9577
|
||||
|
||||
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
|
||||
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
|
||||
|
@ -227,14 +235,23 @@ class MHA(nn.Module):
|
|||
|
||||
# output projection always have the bias (for now)
|
||||
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
|
||||
self.out_proj = out_proj_cls(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
process_group,
|
||||
bias=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
|
||||
out_proj_cls = nn.Linear
|
||||
self.out_proj = out_proj_cls(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=False,
|
||||
**factory_kwargs,
|
||||
)
|
||||
else:
|
||||
self.out_proj = out_proj_cls(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
process_group,
|
||||
bias=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
# need to assign tp attribute so that internlm know it is tensor parallel module
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
for name in ["out_proj", "Wqkv"]:
|
||||
|
|
|
@ -110,8 +110,8 @@ def initialize_model():
|
|||
model = wrap_FSDP_model(model)
|
||||
|
||||
if gpc.config.parallel["tensor"]["mode"] == "fstp":
|
||||
# handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||
handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||
handler._register_sync_parameters_hook()
|
||||
gpc.config.fstp_handler = handler
|
||||
return model
|
||||
|
@ -396,6 +396,9 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
|||
)
|
||||
|
||||
|
||||
tgs_list = []
|
||||
|
||||
|
||||
@llm_timeout(func_name="record_current_batch_training_metrics")
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
|
@ -568,3 +571,9 @@ def record_current_batch_training_metrics(
|
|||
step_count=batch_count,
|
||||
cur_step_loss=loss.item(),
|
||||
)
|
||||
|
||||
if batch_count >= 5:
|
||||
tgs_list.append(tgs_origin)
|
||||
if batch_count == gpc.config.data.total_steps - 1:
|
||||
print(tgs_list, flush=True)
|
||||
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||
|
|
Loading…
Reference in New Issue