Browse Source

[fix] fix test_shard_llama

pull/6083/head
duanjunwen 4 weeks ago
parent
commit
cc0dfddcbc
  1. 38
      colossalai/shardformer/policies/llama.py
  2. 1
      examples/language/llama/benchmark.py
  3. 4
      tests/test_pipeline/test_schedule/test_zerobubble_pp.py

38
colossalai/shardformer/policies/llama.py

@ -394,8 +394,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
# if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
# return []
if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
return []
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
@ -403,20 +403,26 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
if self.pipeline_stage_manager.use_zbv:
return [
{
0: llama_model.embed_tokens.weight,
0: self.model.lm_head.weight,
}
]
else:
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
# if self.pipeline_stage_manager.use_zbv:
# return [
# {
# 0: llama_model.embed_tokens.weight,
# 0: self.model.lm_head.weight,
# }
# ]
# else:
# return [
# {
# 0: llama_model.embed_tokens.weight,
# self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
# }
# ]
return []

1
examples/language/llama/benchmark.py

@ -237,7 +237,6 @@ def main():
).get_v_schedule()
else:
scheduler_nodes = None
# print(f"{dist.get_rank()} {scheduler_nodes[]} ")
plugin = HybridParallelPlugin(
tp_size=args.tp,

4
tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@ -923,9 +923,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@parameterize(
"config",
[
# (1, 2, 2, 1), # Pass
(1, 2, 2, 1), # Pass
# TODO: only supports pp + tp acceleration; will support fully-pp and no-tp hybrid in the future;
(0, 4, 1, 1),
# (0, 4, 1, 1),
# (1, 2, 1, 2),
# (1, 1, 2, 2),
],

Loading…
Cancel
Save