feat(model/linear.py): set block 0 full weight

pull/407/head
huangting4201 2023-10-16 20:13:59 +08:00
parent 82204eea59
commit 0d1fa037dd
4 changed files with 131 additions and 74 deletions
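The change is gated by a new `block_0_full_weight` switch under the `parallel` section of the training config, alongside the `sequence_parallel` and `tensor` mode settings that the modified code already reads. A minimal sketch of such a config entry, assuming the usual InternLM-style `parallel = dict(...)` layout (every key other than the three read by this diff is illustrative, not part of the commit):

# Hypothetical training-config excerpt; only sequence_parallel, tensor["mode"],
# and block_0_full_weight are actually read by the code in this commit.
parallel = dict(
    zero1=8,                              # illustrative
    tensor=dict(size=8, mode="fstp"),     # "fstp" selects the FSTP handler path in initialize_model()
    pipeline=dict(size=1),                # illustrative
    sequence_parallel=True,
    block_0_full_weight=True,             # new flag: keep block 0 as plain nn.Linear (full weight)
)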

View File

@@ -175,6 +175,7 @@ class FeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
+        block_idx: int = 0,
     ):
         super().__init__()

@@ -248,38 +249,62 @@ class FSTPFeedForward(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         multiple_of: int = 256,
+        block_idx: int = 0,
     ):
         super().__init__()

         hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

-        self.w1 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w2 = FSTPLinear(
-            in_features,
-            hidden_features,
-            process_group,
-            bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
-        self.w3 = FSTPLinear(
-            hidden_features,
-            out_features,
-            process_group,
-            bias=bias,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            device=device,
-            dtype=dtype,
-        )
+        if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
+            self.w1 = nn.Linear(
+                in_features,
+                hidden_features,
+                bias,
+                device=device,
+                dtype=dtype,
+            )
+            self.w2 = nn.Linear(
+                in_features,
+                hidden_features,
+                bias,
+                device=device,
+                dtype=dtype,
+            )
+            self.w3 = nn.Linear(
+                hidden_features,
+                out_features,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        else:
+            self.w1 = FSTPLinear(
+                in_features,
+                hidden_features,
+                process_group,
+                bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            self.w2 = FSTPLinear(
+                in_features,
+                hidden_features,
+                process_group,
+                bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )
+            self.w3 = FSTPLinear(
+                hidden_features,
+                out_features,
+                process_group,
+                bias=bias,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                device=device,
+                dtype=dtype,
+            )

     def forward(self, x):
         w1_o = self.w1(x)
@ -449,10 +474,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# print(f"name: {name}", flush=True) # print(f"name: {name}", flush=True)
if name == "out_proj": if name == "out_proj":
self.FSTP_outs.append(child) self.FSTP_outs.append(child)
# self.module_to_index[child] = idx self.module_to_index[child] = idx
if name == "Wqkv": if name == "Wqkv":
self.FSTP_wqkvs.append(child) self.FSTP_wqkvs.append(child)
# self.module_to_index[child] = idx self.module_to_index[child] = idx
if isinstance(child, FSTPLinear): if isinstance(child, FSTPLinear):
self.module_to_index[child] = idx self.module_to_index[child] = idx
self.block_module[idx][index] = child self.block_module[idx][index] = child
@ -489,6 +514,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# start the all-gather for next block # start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER: if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight(block_index + 1) self._all_gather_block_weight(block_index + 1)
# print(f"_all_gather_block_weight for block {block_index+1}", flush=True)
def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
block_index = self.block_to_index[block] block_index = self.block_to_index[block]
@@ -513,7 +539,6 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
-

         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
             block_index = self.module_to_index[module]
             if block_index != 0:
@@ -562,35 +587,37 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
-            if name_index == 4:
-                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                weight_handler.wait()
-                self.FSTP_global_weights[module] = total_weight
-
-                # start the all-gather for next module
-                next_module = self.block_module[block_index][name_index - 1]
-                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                    next_module.weight, self.process_group, async_op=True
-                )
-                self.FSTP_global_handle[next_module] = weights_handler
-            else:
-                handler = self.FSTP_global_handle[module]
-                handler.wait()
-                if name_index != 0:
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
+            if block_index != 0:
+                if name_index == 4:
+                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                    weight_handler.wait()
+                    self.FSTP_global_weights[module] = total_weight
+
+                    # start the all-gather for next module
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+                else:
+                    handler = self.FSTP_global_handle[module]
+                    handler.wait()
+                    if name_index != 0:
+                        next_module = self.block_module[block_index][name_index - 1]
+                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                            next_module.weight, self.process_group, async_op=True
+                        )
+                        self.FSTP_global_handle[next_module] = weights_handler

         def _post_backward_hook_for_module(module, grad_input, grad_output):
-            del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]

         # for block in self.FSTP_blocks:
         #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
         #     block.register_forward_hook(_post_forward_hook_for_block)
         #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
         #     block.register_full_backward_hook(_post_backward_hook_for_block)

         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
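For context on the hook changes above: during backward, the handler waits on the all-gather it previously launched for the module that is about to compute gradients, then immediately launches an asynchronous all-gather for the module that runs next in backward order (name_index - 1), so communication overlaps with gradient computation. A standalone sketch of that prefetch pattern, using plain torch.distributed instead of the repository's all_gather_raw helper (function and dict names here are illustrative, not the repository's):

import torch
import torch.distributed as dist

def pre_backward_prefetch(module, grad_output, *, handles, gathered, next_module, group):
    # Wait for the all-gather that was launched earlier for this module's weight.
    handles[module].wait()
    if next_module is not None:
        # Prefetch: start gathering the next (backward-order) module's weight shard
        # asynchronously so it overlaps with this module's gradient computation.
        shard = next_module.weight
        full = torch.empty(
            (dist.get_world_size(group) * shard.shape[0], *shard.shape[1:]),
            dtype=shard.dtype,
            device=shard.device,
        )
        handles[next_module] = dist.all_gather_into_tensor(full, shard, group=group, async_op=True)
        gathered[next_module] = full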

View File

@@ -78,6 +78,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
         tp_mode: str = "origin_tp",
+        block_idx: int = 0,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@ -103,6 +104,7 @@ class PackedFlashBaseLayer1D(nn.Module):
device=device, device=device,
dtype=dtype, dtype=dtype,
tp_mode=tp_mode, tp_mode=tp_mode,
block_idx=block_idx,
) )
self.dropout1 = nn.Dropout(drop_rate) self.dropout1 = nn.Dropout(drop_rate)
@ -123,6 +125,7 @@ class PackedFlashBaseLayer1D(nn.Module):
bias=False, bias=False,
device=device, device=device,
dtype=dtype, dtype=dtype,
block_idx=block_idx,
) )
else: else:
self.mlp = ParallelFusedMLP( self.mlp = ParallelFusedMLP(
@@ -344,6 +347,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
                     tp_mode=self.tp_mode,
+                    block_idx=lid,
                 )
                 for lid in range(num_layers)
             ]
@@ -410,7 +414,7 @@ class PackedFlashInternLm1D(nn.Module):
             # Evaluation
             if hidden_states.ndim == 3:
                 hidden_states = self.head(hidden_states, gather_dim=1)
             else:  # Training
                 hidden_states = self.head(hidden_states, gather_dim=0)

             if not self.parallel_output:

View File

@@ -51,7 +51,6 @@ class _SeqAllToAll(torch.autograd.Function):

     @staticmethod
     def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
-
         ctx.group = group
         ctx.scatter_idx = scatter_idx
         ctx.gather_idx = gather_idx
@ -91,7 +90,6 @@ class DistributedAttention(torch.nn.Module):
second_scatter_idx: int = 0, second_scatter_idx: int = 0,
second_gather_idx: int = 1, second_gather_idx: int = 1,
) -> None: ) -> None:
super().__init__() super().__init__()
self.local_attn = local_attention self.local_attn = local_attention
self.spg = sequence_process_group self.spg = sequence_process_group
@@ -178,6 +176,7 @@ class MHA(nn.Module):
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         tp_mode: str = "origin_tp",
+        block_idx: int = 0,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -206,14 +205,23 @@ class MHA(nn.Module):
         # notice here should change bias=True
         Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        self.Wqkv = Wqkv_cls(
-            embed_dim,
-            3 * embed_dim,
-            process_group,
-            bias=False,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            **factory_kwargs,
-        )  # according to https://spaces.ac.cn/archives/9577
+        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
+            Wqkv_cls = nn.Linear
+            self.Wqkv = Wqkv_cls(
+                embed_dim,
+                3 * embed_dim,
+                bias=False,
+                **factory_kwargs,
+            )
+        else:
+            self.Wqkv = Wqkv_cls(
+                embed_dim,
+                3 * embed_dim,
+                process_group,
+                bias=False,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                **factory_kwargs,
+            )  # according to https://spaces.ac.cn/archives/9577

         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -227,14 +235,23 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
-        self.out_proj = out_proj_cls(
-            embed_dim,
-            embed_dim,
-            process_group,
-            bias=False,
-            sequence_parallel=gpc.config.parallel.sequence_parallel,
-            **factory_kwargs,
-        )
+        if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
+            out_proj_cls = nn.Linear
+            self.out_proj = out_proj_cls(
+                embed_dim,
+                embed_dim,
+                bias=False,
+                **factory_kwargs,
+            )
+        else:
+            self.out_proj = out_proj_cls(
+                embed_dim,
+                embed_dim,
+                process_group,
+                bias=False,
+                sequence_parallel=gpc.config.parallel.sequence_parallel,
+                **factory_kwargs,
+            )
         # need to assign tp attribute so that internlm know it is tensor parallel module
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
             for name in ["out_proj", "Wqkv"]:

View File

@@ -110,8 +110,8 @@ def initialize_model():
         model = wrap_FSDP_model(model)

     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        # handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
-        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        # handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
     return model
@@ -396,6 +396,9 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     )


+tgs_list = []
+
+
 @llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
@ -568,3 +571,9 @@ def record_current_batch_training_metrics(
step_count=batch_count, step_count=batch_count,
cur_step_loss=loss.item(), cur_step_loss=loss.item(),
) )
if batch_count >= 5:
tgs_list.append(tgs_origin)
if batch_count == gpc.config.data.total_steps - 1:
print(tgs_list, flush=True)
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)