mirror of https://github.com/InternLM/InternLM
feat(model/linear.py): set block 0 full weight
parent
82204eea59
commit
0d1fa037dd
|
|
@ -175,6 +175,7 @@ class FeedForward(nn.Module):
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
multiple_of: int = 256,
|
multiple_of: int = 256,
|
||||||
|
block_idx: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@ -248,11 +249,35 @@ class FSTPFeedForward(nn.Module):
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
multiple_of: int = 256,
|
multiple_of: int = 256,
|
||||||
|
block_idx: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
|
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
|
||||||
|
|
||||||
|
if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
|
||||||
|
self.w1 = nn.Linear(
|
||||||
|
in_features,
|
||||||
|
hidden_features,
|
||||||
|
bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.w2 = nn.Linear(
|
||||||
|
in_features,
|
||||||
|
hidden_features,
|
||||||
|
bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.w3 = nn.Linear(
|
||||||
|
hidden_features,
|
||||||
|
out_features,
|
||||||
|
bias=bias,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.w1 = FSTPLinear(
|
self.w1 = FSTPLinear(
|
||||||
in_features,
|
in_features,
|
||||||
hidden_features,
|
hidden_features,
|
||||||
|
|
@ -449,10 +474,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
# print(f"name: {name}", flush=True)
|
# print(f"name: {name}", flush=True)
|
||||||
if name == "out_proj":
|
if name == "out_proj":
|
||||||
self.FSTP_outs.append(child)
|
self.FSTP_outs.append(child)
|
||||||
# self.module_to_index[child] = idx
|
self.module_to_index[child] = idx
|
||||||
if name == "Wqkv":
|
if name == "Wqkv":
|
||||||
self.FSTP_wqkvs.append(child)
|
self.FSTP_wqkvs.append(child)
|
||||||
# self.module_to_index[child] = idx
|
self.module_to_index[child] = idx
|
||||||
if isinstance(child, FSTPLinear):
|
if isinstance(child, FSTPLinear):
|
||||||
self.module_to_index[child] = idx
|
self.module_to_index[child] = idx
|
||||||
self.block_module[idx][index] = child
|
self.block_module[idx][index] = child
|
||||||
|
|
@ -489,6 +514,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
# start the all-gather for next block
|
# start the all-gather for next block
|
||||||
if block_index + 1 < gpc.config.NUM_LAYER:
|
if block_index + 1 < gpc.config.NUM_LAYER:
|
||||||
self._all_gather_block_weight(block_index + 1)
|
self._all_gather_block_weight(block_index + 1)
|
||||||
|
# print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
|
||||||
|
|
||||||
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
|
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
|
||||||
block_index = self.block_to_index[block]
|
block_index = self.block_to_index[block]
|
||||||
|
|
@ -513,7 +539,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for module in fsdp_modules:
|
for module in fsdp_modules:
|
||||||
del self.FSTP_global_weights[module]
|
del self.FSTP_global_weights[module]
|
||||||
|
|
||||||
|
|
||||||
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
|
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
|
||||||
block_index = self.module_to_index[module]
|
block_index = self.module_to_index[module]
|
||||||
if block_index != 0:
|
if block_index != 0:
|
||||||
|
|
@ -562,6 +587,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
|
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
|
||||||
block_index = self.module_to_index[module]
|
block_index = self.module_to_index[module]
|
||||||
name_index = self.module_name_index[module]
|
name_index = self.module_name_index[module]
|
||||||
|
if block_index != 0:
|
||||||
if name_index == 4:
|
if name_index == 4:
|
||||||
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||||
weight_handler.wait()
|
weight_handler.wait()
|
||||||
|
|
@ -584,6 +610,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.FSTP_global_handle[next_module] = weights_handler
|
self.FSTP_global_handle[next_module] = weights_handler
|
||||||
|
|
||||||
def _post_backward_hook_for_module(module, grad_input, grad_output):
|
def _post_backward_hook_for_module(module, grad_input, grad_output):
|
||||||
|
if module in self.FSTP_global_weights:
|
||||||
del self.FSTP_global_weights[module]
|
del self.FSTP_global_weights[module]
|
||||||
|
|
||||||
# for block in self.FSTP_blocks:
|
# for block in self.FSTP_blocks:
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
use_swiglu: bool = True,
|
use_swiglu: bool = True,
|
||||||
use_flash_attn: bool = True,
|
use_flash_attn: bool = True,
|
||||||
tp_mode: str = "origin_tp",
|
tp_mode: str = "origin_tp",
|
||||||
|
block_idx: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
|
|
@ -103,6 +104,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
tp_mode=tp_mode,
|
tp_mode=tp_mode,
|
||||||
|
block_idx=block_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dropout1 = nn.Dropout(drop_rate)
|
self.dropout1 = nn.Dropout(drop_rate)
|
||||||
|
|
@ -123,6 +125,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
bias=False,
|
bias=False,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
block_idx=block_idx,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.mlp = ParallelFusedMLP(
|
self.mlp = ParallelFusedMLP(
|
||||||
|
|
@ -344,6 +347,7 @@ class PackedFlashInternLm1D(nn.Module):
|
||||||
use_swiglu=use_swiglu,
|
use_swiglu=use_swiglu,
|
||||||
use_flash_attn=use_flash_attn,
|
use_flash_attn=use_flash_attn,
|
||||||
tp_mode=self.tp_mode,
|
tp_mode=self.tp_mode,
|
||||||
|
block_idx=lid,
|
||||||
)
|
)
|
||||||
for lid in range(num_layers)
|
for lid in range(num_layers)
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,6 @@ class _SeqAllToAll(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
|
def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
|
||||||
|
|
||||||
ctx.group = group
|
ctx.group = group
|
||||||
ctx.scatter_idx = scatter_idx
|
ctx.scatter_idx = scatter_idx
|
||||||
ctx.gather_idx = gather_idx
|
ctx.gather_idx = gather_idx
|
||||||
|
|
@ -91,7 +90,6 @@ class DistributedAttention(torch.nn.Module):
|
||||||
second_scatter_idx: int = 0,
|
second_scatter_idx: int = 0,
|
||||||
second_gather_idx: int = 1,
|
second_gather_idx: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.local_attn = local_attention
|
self.local_attn = local_attention
|
||||||
self.spg = sequence_process_group
|
self.spg = sequence_process_group
|
||||||
|
|
@ -178,6 +176,7 @@ class MHA(nn.Module):
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
tp_mode: str = "origin_tp",
|
tp_mode: str = "origin_tp",
|
||||||
|
block_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -206,6 +205,15 @@ class MHA(nn.Module):
|
||||||
|
|
||||||
# notice here should change bias=True
|
# notice here should change bias=True
|
||||||
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
|
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
|
||||||
|
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
|
||||||
|
Wqkv_cls = nn.Linear
|
||||||
|
self.Wqkv = Wqkv_cls(
|
||||||
|
embed_dim,
|
||||||
|
3 * embed_dim,
|
||||||
|
bias=False,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.Wqkv = Wqkv_cls(
|
self.Wqkv = Wqkv_cls(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
3 * embed_dim,
|
3 * embed_dim,
|
||||||
|
|
@ -227,6 +235,15 @@ class MHA(nn.Module):
|
||||||
|
|
||||||
# output projection always have the bias (for now)
|
# output projection always have the bias (for now)
|
||||||
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
|
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
|
||||||
|
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
|
||||||
|
out_proj_cls = nn.Linear
|
||||||
|
self.out_proj = out_proj_cls(
|
||||||
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=False,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
self.out_proj = out_proj_cls(
|
self.out_proj = out_proj_cls(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
|
|
|
||||||
|
|
@ -110,8 +110,8 @@ def initialize_model():
|
||||||
model = wrap_FSDP_model(model)
|
model = wrap_FSDP_model(model)
|
||||||
|
|
||||||
if gpc.config.parallel["tensor"]["mode"] == "fstp":
|
if gpc.config.parallel["tensor"]["mode"] == "fstp":
|
||||||
# handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||||
handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
# handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
|
||||||
handler._register_sync_parameters_hook()
|
handler._register_sync_parameters_hook()
|
||||||
gpc.config.fstp_handler = handler
|
gpc.config.fstp_handler = handler
|
||||||
return model
|
return model
|
||||||
|
|
@ -396,6 +396,9 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
tgs_list = []
|
||||||
|
|
||||||
|
|
||||||
@llm_timeout(func_name="record_current_batch_training_metrics")
|
@llm_timeout(func_name="record_current_batch_training_metrics")
|
||||||
def record_current_batch_training_metrics(
|
def record_current_batch_training_metrics(
|
||||||
get_tflops_func,
|
get_tflops_func,
|
||||||
|
|
@ -568,3 +571,9 @@ def record_current_batch_training_metrics(
|
||||||
step_count=batch_count,
|
step_count=batch_count,
|
||||||
cur_step_loss=loss.item(),
|
cur_step_loss=loss.item(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if batch_count >= 5:
|
||||||
|
tgs_list.append(tgs_origin)
|
||||||
|
if batch_count == gpc.config.data.total_steps - 1:
|
||||||
|
print(tgs_list, flush=True)
|
||||||
|
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue