feat(model/linear.py): set block 0 full weight

pull/407/head
huangting4201 2023-10-16 20:13:59 +08:00
parent 82204eea59
commit 0d1fa037dd
4 changed files with 131 additions and 74 deletions

internlm/model/linear.py

@@ -175,6 +175,7 @@ class FeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
+        block_idx: int = 0,
     ):
         super().__init__()
@@ -248,38 +249,62 @@ class FSTPFeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
+        block_idx: int = 0,
     ):
         super().__init__()
         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w2 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w3 = FSTPLinear(
-            hidden_features,
-            out_features,
-            process_group,
-            bias=bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
+        if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
+            self.w1 = nn.Linear(
+                in_features,
+                hidden_features,
+                bias,
+                device=device,
+                dtype=dtype,
+            )
+            self.w2 = nn.Linear(
+                in_features,
+                hidden_features,
+                bias,
+                device=device,
+                dtype=dtype,
+            )
+            self.w3 = nn.Linear(
+                hidden_features,
+                out_features,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        else:
+            self.w1 = FSTPLinear(
+                in_features,
+                hidden_features,
+                process_group,
+                bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            self.w2 = FSTPLinear(
+                in_features,
+                hidden_features,
+                process_group,
+                bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            self.w3 = FSTPLinear(
+                hidden_features,
+                out_features,
+                process_group,
+                bias=bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )

     def forward(self, x):
         w1_o = self.w1(x)
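This hunk is the core of the commit: when gpc.config.parallel.block_0_full_weight is set, block 0's MLP keeps full replicated weights as plain nn.Linear modules, so its forward never waits on a weight all-gather, while every other block keeps FSTPLinear shards. A minimal sketch of the same dispatch outside InternLM's classes (make_mlp_linear and tp_world_size are illustrative names, and the sharded branch is a plain stand-in for FSTPLinear, not its real implementation):

    import torch.nn as nn

    def make_mlp_linear(in_f: int, out_f: int, block_idx: int, tp_world_size: int,
                        block_0_full_weight: bool = True) -> nn.Linear:
        if block_idx == 0 and block_0_full_weight:
            # Replicated: every rank owns the full [out_f, in_f] weight,
            # so no communication is needed before the matmul.
            return nn.Linear(in_f, out_f, bias=False)
        # Stand-in for FSTPLinear: each rank stores only out_f // tp_world_size
        # rows and must all-gather the rest before computing the full output.
        return nn.Linear(in_f, out_f // tp_world_size, bias=False)

The trade-off is memory for latency: block 0 costs a full weight copy on every rank, but it anchors the prefetch pipeline, leaving a whole block of compute to hide the all-gather for block 1.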
@@ -449,10 +474,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                     # print(f"name: {name}", flush=True)
                     if name == "out_proj":
                         self.FSTP_outs.append(child)
-                        # self.module_to_index[child] = idx
+                        self.module_to_index[child] = idx
                     if name == "Wqkv":
                         self.FSTP_wqkvs.append(child)
-                        # self.module_to_index[child] = idx
+                        self.module_to_index[child] = idx
                     if isinstance(child, FSTPLinear):
                         self.module_to_index[child] = idx
                         self.block_module[idx][index] = child
@@ -489,6 +514,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight(block_index + 1)
+                # print(f"_all_gather_block_weight for block {block_index+1}", flush=True)

         def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
             block_index = self.block_to_index[block]
@@ -512,14 +538,13 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 del self.block_handles[block]
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]

         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
             block_index = self.module_to_index[module]
+            if block_index != 0:
                 handler = self.FSTP_global_handle[module]
                 handler.wait()

         def _post_forward_hook_for_module(module: nn.Module, input, output):
             if module in self.FSTP_global_weights:
                 del self.FSTP_global_weights[module]
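These hooks split communication into a launch site and a wait site: a weight all-gather is kicked off early with async_op=True, its Work handle is stashed in FSTP_global_handle, and the module's pre-forward hook only waits (now skipped entirely for block 0, whose weights are already full). A minimal stand-in for what all_gather_raw provides, using plain torch.distributed and assuming the weight is sharded along dim 0:

    import torch
    import torch.distributed as dist

    def all_gather_weight_async(weight: torch.Tensor, group):
        """Launch a non-blocking all-gather of a row-sharded weight.
        Returns the full-weight buffer and the Work handle to wait() on."""
        world_size = dist.get_world_size(group)
        full_weight = weight.new_empty(world_size * weight.shape[0], *weight.shape[1:])
        handle = dist.all_gather_into_tensor(
            full_weight, weight.contiguous(), group=group, async_op=True
        )
        return full_weight, handle

    # Launch while block k computes, wait in block k+1's pre-forward hook:
    # full_w, h = all_gather_weight_async(module.weight, tp_group); ...; h.wait()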
@@ -558,46 +583,48 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 del self.block_handles[block]
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]

         def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
-            if name_index == 4:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
+            if block_index != 0:
+                if name_index == 4:
+                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handler.wait()
+                    self.FSTP_global_weights[module] = total_weight
-
-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.FSTP_global_handle[next_module] = weights_handler
-            else:
-                handler = self.FSTP_global_handle[module]
-                handler.wait()
-                if name_index != 0:
-                    # start the all-gather for next module
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
+                else:
+                    handler = self.FSTP_global_handle[module]
+                    handler.wait()
+
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler

         def _post_backward_hook_for_module(module, grad_input, grad_output):
-            del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]

         # for block in self.FSTP_blocks:
         #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
         #     block.register_forward_hook(_post_forward_hook_for_block)
         #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
         #     block.register_full_backward_hook(_post_backward_hook_for_block)

         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)

         # for wqkv in self.FSTP_wqkvs:
         #     wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)

         for module in self.FSTP_modules:
             module.register_forward_pre_hook(_pre_forward_hook_for_module)
             module.register_forward_hook(_post_forward_hook_for_module)
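The backward hooks rely on autograd visiting modules in reverse order: the pre-backward hook on module i is the natural place to start the all-gather for module i - 1, so the fetch overlaps with module i's gradient computation. A self-contained toy showing that ordering with register_full_backward_pre_hook (the print marks where all_gather_raw would be launched):

    import torch
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    def make_pre_backward_hook(i: int):
        def hook(module, grad_output):
            if i > 0:
                # Handler equivalent: all_gather_raw(layers[i - 1].weight, ..., async_op=True)
                print(f"backward at layer {i}: prefetch weight for layer {i - 1}")
        return hook

    for i, layer in enumerate(layers):
        layer.register_full_backward_pre_hook(make_pre_backward_hook(i))

    out = torch.randn(4, 8)
    for layer in layers:
        out = layer(out)
    out.sum().backward()  # fires for layer 2, then 1, then 0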

internlm/model/modeling_internlm.py

@@ -78,6 +78,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         tp_mode: str = "origin_tp",
+        block_idx: int = 0,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -103,6 +104,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             device=device,
             dtype=dtype,
             tp_mode=tp_mode,
+            block_idx=block_idx,
         )

         self.dropout1 = nn.Dropout(drop_rate)
@@ -123,6 +125,7 @@ class PackedFlashBaseLayer1D(nn.Module):
                 bias=False,
                 device=device,
                 dtype=dtype,
+                block_idx=block_idx,
             )
         else:
             self.mlp = ParallelFusedMLP(
@@ -344,6 +347,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     tp_mode=self.tp_mode,
+                    block_idx=lid,
                 )
                 for lid in range(num_layers)
             ]
@@ -410,7 +414,7 @@ class PackedFlashInternLm1D(nn.Module):
             # Evaluation
             if hidden_states.ndim == 3:
                 hidden_states = self.head(hidden_states, gather_dim=1)
-            else: # Training
+            else:  # Training
                 hidden_states = self.head(hidden_states, gather_dim=0)

         if not self.parallel_output:

internlm/model/multi_head_attention.py

@@ -51,7 +51,6 @@ class _SeqAllToAll(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
         ctx.group = group
-
         ctx.scatter_idx = scatter_idx
         ctx.gather_idx = gather_idx
@@ -91,7 +90,6 @@ class DistributedAttention(torch.nn.Module):
         second_scatter_idx: int = 0,
         second_gather_idx: int = 1,
     ) -> None:
-
         super().__init__()
         self.local_attn = local_attention
         self.spg = sequence_process_group
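For context, _SeqAllToAll is the primitive behind DistributedAttention's sequence parallelism: it re-shards a tensor so that the dimension at gather_idx becomes whole on every rank while the dimension at scatter_idx becomes sharded (e.g., trading a sequence shard for a head shard before attention). A dense sketch of the forward with plain torch.distributed, assuming both dimensions divide evenly by the group size (an illustration, not the repo's implementation):

    import torch
    import torch.distributed as dist

    def seq_all_to_all(x: torch.Tensor, group, scatter_idx: int, gather_idx: int) -> torch.Tensor:
        world_size = dist.get_world_size(group)
        # Split locally along scatter_idx, exchange one chunk with every rank,
        # then reassemble the received pieces along gather_idx.
        inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_idx)]
        outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
        dist.all_to_all(outputs, inputs, group=group)
        return torch.cat(outputs, dim=gather_idx)

The backward pass is the same exchange with scatter_idx and gather_idx swapped, which is exactly why forward() stores the group and both indices on ctx.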
@@ -178,6 +176,7 @@ class MHA(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         tp_mode: str = "origin_tp",
+        block_idx: int = 0,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -206,14 +205,23 @@ class MHA(nn.Module):
         # notice here should change bias=True
         Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        self.Wqkv = Wqkv_cls(
-            embed_dim,
-            3 * embed_dim,
-            process_group,
-            bias=False,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            **factory_kwargs,
-        )  # according to https://spaces.ac.cn/archives/9577
+        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
+            Wqkv_cls = nn.Linear
+            self.Wqkv = Wqkv_cls(
+                embed_dim,
+                3 * embed_dim,
+                bias=False,
+                **factory_kwargs,
+            )
+        else:
+            self.Wqkv = Wqkv_cls(
+                embed_dim,
+                3 * embed_dim,
+                process_group,
+                bias=False,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                **factory_kwargs,
+            )  # according to https://spaces.ac.cn/archives/9577

         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -227,14 +235,23 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        self.out_proj = out_proj_cls(
-            embed_dim,
-            embed_dim,
-            process_group,
-            bias=False,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            **factory_kwargs,
-        )
+        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
+            out_proj_cls = nn.Linear
+            self.out_proj = out_proj_cls(
+                embed_dim,
+                embed_dim,
+                bias=False,
+                **factory_kwargs,
+            )
+        else:
+            self.out_proj = out_proj_cls(
+                embed_dim,
+                embed_dim,
+                process_group,
+                bias=False,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                **factory_kwargs,
+            )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:
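The context lines closing this hunk tag out_proj and Wqkv as tensor-parallel for the rest of the runtime (gradient-norm computation, checkpointing, and similar bookkeeping); note that with block_0_full_weight enabled, the block-0 projections are plain nn.Linear yet still pass through this loop. A sketch of what such tagging typically looks like; the attribute name here is illustrative, not necessarily the constant InternLM uses:

    import torch.nn as nn

    def mark_tensor_parallel(module: nn.Module) -> None:
        # Downstream code can then check getattr(param, "is_tensor_parallel", False).
        for param in module.parameters():
            setattr(param, "is_tensor_parallel", True)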

internlm/train/training_internlm.py

@@ -110,8 +110,8 @@ def initialize_model():
         model = wrap_FSDP_model(model)

     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        # handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler

     return model
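Enabling this path requires a config consistent with the reads in this commit: gpc.config.parallel["tensor"]["mode"] == "fstp" selects the handler, and gpc.config.parallel.block_0_full_weight gates the replicated first block. A sketch of the relevant fragment of a training config (field values illustrative; only the keys actually read by this diff are load-bearing):

    parallel = dict(
        zero1=8,
        tensor=dict(size=8, mode="fstp"),
        sequence_parallel=True,
        block_0_full_weight=True,
    )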
@@ -396,6 +396,9 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     )


+tgs_list = []
+
+
 @llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
@@ -568,3 +571,9 @@ def record_current_batch_training_metrics(
             step_count=batch_count,
             cur_step_loss=loss.item(),
         )
+
+        if batch_count >= 5:
+            tgs_list.append(tgs_origin)
+        if batch_count == gpc.config.data.total_steps - 1:
+            print(tgs_list, flush=True)
+            print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)