mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix test_shard_llama
parent d0ec221b38
commit cc0dfddcbc
@@ -394,8 +394,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        # if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
-        #     return []
+        if self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv:
+            return []
         llama_model = self.model.model
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             if (
@@ -403,20 +403,26 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                 and self.pipeline_stage_manager.num_stages > 1
             ):
                 # tie weights
-                if self.pipeline_stage_manager.use_zbv:
-                    return [
-                        {
-                            0: llama_model.embed_tokens.weight,
-                            0: self.model.lm_head.weight,
-                        }
-                    ]
-                else:
-                    return [
-                        {
-                            0: llama_model.embed_tokens.weight,
-                            self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-                        }
-                    ]
+                return [
+                    {
+                        0: llama_model.embed_tokens.weight,
+                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                    }
+                ]
+                # if self.pipeline_stage_manager.use_zbv:
+                #     return [
+                #         {
+                #             0: llama_model.embed_tokens.weight,
+                #             0: self.model.lm_head.weight,
+                #         }
+                #     ]
+                # else:
+                #     return [
+                #         {
+                #             0: llama_model.embed_tokens.weight,
+                #             self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                #         }
+                #     ]
         return []
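A note on why the early "return []" under use_zbv replaces the retired branch above: a Python dict literal cannot express "both weights live on stage 0", because a duplicate key keeps only the last value, so the commented-out {0: embed_tokens.weight, 0: lm_head.weight} mapping would silently drop the embedding entry. A minimal, self-contained sketch of that pitfall (the tensors are placeholders, not the real LlamaForCausalLM weights):

# Sketch of the duplicate-key pitfall behind the retired zbv branch; tensor shapes
# are arbitrary stand-ins for the embedding and lm_head weights.
import torch

embed_weight = torch.zeros(4, 2)   # stand-in for llama_model.embed_tokens.weight
lm_head_weight = torch.ones(4, 2)  # stand-in for self.model.lm_head.weight

shared = {0: embed_weight, 0: lm_head_weight}  # duplicate key: only the last entry survives
assert len(shared) == 1
assert shared[0] is lm_head_weight  # the embedding weight was silently dropped

Returning an empty list early when use_zbv is set sidesteps this; presumably the zero-bubble V-schedule keeps the embedding and lm_head chunks on the same rank, so there is no cross-stage tying left to register.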
@@ -237,7 +237,6 @@ def main():
         ).get_v_schedule()
     else:
         scheduler_nodes = None
-    # print(f"{dist.get_rank()} {scheduler_nodes[]} ")

     plugin = HybridParallelPlugin(
         tp_size=args.tp,
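The hunk above only removes a leftover debug print. For orientation, a self-contained sketch of the schedule selection it sat in: only scheduler_nodes and tp_size=args.tp come from the diff; the flag name, the helper, and the placeholder values are illustrative stand-ins, not taken from the script.

# Hypothetical sketch of main()'s schedule selection around the removed debug print.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp_style", type=str, default="1f1b")
args = parser.parse_args([])


def build_v_schedule(args):
    # Stand-in for the `...).get_v_schedule()` call shown as context above.
    return ["zbv_node_placeholder"]


if args.pp_style == "zbv":
    scheduler_nodes = build_v_schedule(args)
else:
    scheduler_nodes = None  # presumably lets the plugin fall back to its default schedule

plugin_kwargs = dict(tp_size=args.tp)  # forwarded to HybridParallelPlugin in the real script
print(plugin_kwargs, scheduler_nodes)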
@@ -923,9 +923,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
 @parameterize(
     "config",
     [
-        # (1, 2, 2, 1), # Pass
+        (1, 2, 2, 1), # Pass
         # TODO: only support pp + tp accleration; Will support fully pp and None tp Hybrid in furture;
-        (0, 4, 1, 1),
+        # (0, 4, 1, 1),
         # (1, 2, 1, 2),
         # (1, 1, 2, 2),
     ],
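The change above just flips which parallel configuration the test exercises: (1, 2, 2, 1) is re-enabled and (0, 4, 1, 1) is parked behind the TODO. For readers unfamiliar with the pattern, a minimal sketch of how a parameterize-style decorator (presumably colossalai.testing.parameterize) feeds each tuple to the test body; this mirrors the usage in the hunk but is not ColossalAI's implementation, and the meaning of the tuple fields is not spelled out in this diff, so it is left opaque here.

# Toy stand-in for a parameterize-style decorator: call the wrapped test once per value.
from functools import wraps
from typing import Tuple


def parameterize(arg_name, values):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator


@parameterize(
    "config",
    [
        (1, 2, 2, 1),  # the case this commit re-enables
    ],
)
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
    # The real test unpacks the tuple into its parallel sizes; the field meaning is
    # not shown in this hunk, so the sketch just reports it.
    print(f"running config={config}")


run_with_booster_moehybridplugin()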