[fix] fix test_shard_llama

pull/6083/head
duanjunwen 2024-10-25 09:01:13 +00:00
parent d0ec221b38
commit cc0dfddcbc
3 changed files with 24 additions and 19 deletions
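This commit enables the early `return []` in `LlamaForCausalLMPolicy.get_shared_params` when the zero-bubble V schedule (`use_zbv`) is active, comments out the now-unreachable ZBV weight-tying branch, removes a leftover debug print from the benchmark script, and switches which parallelism config is enabled in the test_shard_llama parameterization.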

@@ -394,8 +394,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        # if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
-        #     return []
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
@@ -403,20 +403,26 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                 and self.pipeline_stage_manager.num_stages > 1
             ):
                 # tie weights
-                if self.pipeline_stage_manager.use_zbv:
-                    return [
-                        {
-                            0: llama_model.embed_tokens.weight,
-                            0: self.model.lm_head.weight,
-                        }
-                    ]
-                else:
-                    return [
-                        {
-                            0: llama_model.embed_tokens.weight,
-                            self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                        }
-                    ]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
+                # if self.pipeline_stage_manager.use_zbv:
+                #     return [
+                #         {
+                #             0: llama_model.embed_tokens.weight,
+                #             0: self.model.lm_head.weight,
+                #         }
+                #     ]
+                # else:
+                #     return [
+                #         {
+                #             0: llama_model.embed_tokens.weight,
+                #             self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                #         }
+                #     ]
         return []
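Note on the hunk above: with the early return now active, `get_shared_params` reports no cross-stage shared parameters under the zero-bubble V schedule, presumably because the V-shaped stage layout places both the embedding chunk and the LM head chunk on the same pipeline rank, so weight tying needs no inter-stage sync. The branch that was commented out was also subtly wrong on its own: a Python dict literal keeps only the last value bound to a repeated key, so `{0: embed_tokens.weight, 0: lm_head.weight}` never held both tensors. A standalone sketch of that pitfall (illustrative snippet, not from the commit):

# A dict literal silently keeps the LAST value bound to a duplicated key.
shared = {
    0: "embed_tokens.weight",
    0: "lm_head.weight",  # overwrites the entry above
}
print(shared)  # -> {0: 'lm_head.weight'}; embed_tokens.weight was dropped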

@@ -237,7 +237,6 @@ def main():
         ).get_v_schedule()
     else:
         scheduler_nodes = None
-    # print(f"{dist.get_rank()} {scheduler_nodes[]} ")
     plugin = HybridParallelPlugin(
         tp_size=args.tp,
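The deleted line above was leftover debug output, and `scheduler_nodes[]` would be a syntax error if it were ever uncommented. If such a trace were wanted, a guarded, valid form might look like this (a sketch, assuming `torch.distributed` is initialized at this point):

import torch.distributed as dist

if dist.is_initialized() and dist.get_rank() == 0:
    # Emit the computed schedule once, from rank 0, rather than from every process.
    print(f"rank {dist.get_rank()}: scheduler_nodes={scheduler_nodes}")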

@@ -923,9 +923,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
 @parameterize(
     "config",
     [
-        # (1, 2, 2, 1),  # Pass
+        (1, 2, 2, 1),  # Pass
         # TODO: only supports pp + tp acceleration; will support full pp and no-tp hybrid in the future
-        (0, 4, 1, 1),
+        # (0, 4, 1, 1),
         # (1, 2, 1, 2),
         # (1, 1, 2, 2),
     ],
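For context, ColossalAI's `parameterize` test decorator reruns the wrapped function once per value in the list, passing each tuple as the named argument. A minimal sketch under that assumption (`run_case` is a hypothetical name; this hunk does not show how the tuple fields are unpacked):

from colossalai.testing import parameterize

@parameterize(
    "config",
    [
        (1, 2, 2, 1),
        (0, 4, 1, 1),
    ],
)
def run_case(config):
    # Called once per tuple listed above.
    print("running shard_llama case with config:", config)

run_case()  # executes the body for every listed config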