mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix test_shard_llama
parent
d0ec221b38
commit
cc0dfddcbc
|
@ -394,8 +394,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
# if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
|
||||
# return []
|
||||
if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
|
||||
return []
|
||||
llama_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
|
@ -403,20 +403,26 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
if self.pipeline_stage_manager.use_zbv:
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
0: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
# if self.pipeline_stage_manager.use_zbv:
|
||||
# return [
|
||||
# {
|
||||
# 0: llama_model.embed_tokens.weight,
|
||||
# 0: self.model.lm_head.weight,
|
||||
# }
|
||||
# ]
|
||||
# else:
|
||||
# return [
|
||||
# {
|
||||
# 0: llama_model.embed_tokens.weight,
|
||||
# self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
# }
|
||||
# ]
|
||||
return []
|
||||
|
||||
|
||||
|
|
|
@ -237,7 +237,6 @@ def main():
|
|||
).get_v_schedule()
|
||||
else:
|
||||
scheduler_nodes = None
|
||||
# print(f"{dist.get_rank()} {scheduler_nodes[]} ")
|
||||
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
|
|
|
@ -923,9 +923,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
# (1, 2, 2, 1), # Pass
|
||||
(1, 2, 2, 1), # Pass
|
||||
# TODO: only pp + tp acceleration is supported; full pp with no tp in Hybrid will be supported in the future.
|
||||
(0, 4, 1, 1),
|
||||
# (0, 4, 1, 1),
|
||||
# (1, 2, 1, 2),
|
||||
# (1, 1, 2, 2),
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue