
[fix] fix mixtral modeling & policy; update wait handles; benchmarking llama hybrid in progress;

pull/6114/head
duanjunwen 7 days ago
commit 5c2ebbfd48
  1. colossalai/pipeline/schedule/zero_bubble_pp.py (13 changed lines)
  2. colossalai/shardformer/modeling/mixtral.py (1 changed line)
  3. colossalai/shardformer/policies/mixtral.py (2 changed lines)
  4. examples/language/mixtral/benchmark.py (2 changed lines)

colossalai/pipeline/schedule/zero_bubble_pp.py (13 changed lines)

@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        num_microbatch: Optional[int] = None,
        microbatch_size: Optional[int] = None,
        enable_metadata_cache: bool = True,
-       overlap_p2p: bool = False,
+       overlap_p2p: bool = True,
    ):
        super().__init__(stage_manager)
        # Not support overlap_p2p so far
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
        schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
        for it in range(len(schedule)):
            scheduled_node = schedule[it]
            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                # communication
                communication_func = self.communication_map[scheduled_node.type]
                wait_handle = communication_func(scheduled_node.chunk)
                self.wait_handles.append(wait_handle)
            elif scheduled_node.type == "F":
+               for h in self.wait_handles:
+                   for hh in h:
+                       hh.wait()
                self.schedule_f(
                    scheduled_node=scheduled_node,
                    model_chunk=model_chunk,
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                    outputs=outputs,
                )
            elif scheduled_node.type == "B":
+               for h in self.wait_handles:
+                   for hh in h:
+                       hh.wait()
                self.schedule_b(
                    scheduled_node=scheduled_node,
                    model_chunk=model_chunk,
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                    model_chunk_id=scheduled_node.chunk,
                    optimizer=optimizer,
                )
-           # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
+       for h in self.wait_handles:
+           for hh in h:
+               hh.wait()
        # return loss & output
        if outputs is not None:
            outputs = merge_batch(outputs)
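
Note on the wait-handle change above: with overlap_p2p enabled, each entry in self.wait_handles appears to be a list of async work objects returned by one communication call (for example, the list returned by torch.distributed.batch_isend_irecv), which is why the drain is a nested loop. A minimal, self-contained sketch of that collect-then-drain pattern, using a stand-in handle class instead of a real process group (FakeWork and fake_communication are illustrative names, not part of ColossalAI):

# Sketch of the collect-then-drain pattern; only .wait() mirrors the real API.
class FakeWork:
    def __init__(self, tag: str):
        self.tag = tag

    def wait(self):
        # A real torch.distributed work handle would block here until the
        # async send/recv has completed.
        print(f"waited on {self.tag}")


def fake_communication(chunk: int):
    # Real code would build P2P ops and call torch.distributed.batch_isend_irecv,
    # which returns a list of work handles.
    return [FakeWork(f"send-chunk{chunk}"), FakeWork(f"recv-chunk{chunk}")]


wait_handles = []  # list of lists: one inner list per communication call
for chunk in range(2):
    wait_handles.append(fake_communication(chunk))

# Drain every outstanding handle before a compute step (the F/B branches above)
# and once more after the schedule loop.
for handles in wait_handles:
    for handle in handles:
        handle.wait()
wait_handles.clear()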

colossalai/shardformer/modeling/mixtral.py (1 changed line)

@@ -381,7 +381,6 @@ class MixtralPipelineForwards:
                    output_router_logits,
                    use_cache,
                )
            hidden_states = layer_outputs[0]
            if use_cache:

colossalai/shardformer/policies/mixtral.py (2 changed lines)

@@ -214,7 +214,6 @@ class MixtralPolicy(Policy):
                    suffix="block_sparse_moe.gate",
                    target_module=LinearWithGradAccum,
                    kwargs={
-                       "gather_output": True,
                        "fp8_communication": self.shard_config.fp8_communication,
                        "use_zbv": use_zbv,
                    },
@@ -414,7 +413,6 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                    suffix="lm_head",
                    target_module=LinearWithGradAccum,
                    kwargs=dict(
-                       gather_output=True,
                        fp8_communication=self.shard_config.fp8_communication,
                        use_zbv=use_zbv,
                    ),
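
Both removed lines drop gather_output from the kwargs forwarded to LinearWithGradAccum. A plausible reading, not stated in the diff itself, is that gather_output only applies to a column-parallel linear and is not a parameter of LinearWithGradAccum, so forwarding it would fail when the replacement module is constructed. A toy illustration of that failure mode with hypothetical stand-in classes (neither class below is the real ColossalAI layer):

# Hypothetical stand-ins; real ColossalAI layer signatures may differ.
class ColumnParallelLinearStub:
    def __init__(self, gather_output: bool = False, use_zbv: bool = False):
        self.gather_output = gather_output
        self.use_zbv = use_zbv


class GradAccumLinearStub:
    def __init__(self, use_zbv: bool = False):  # no gather_output parameter
        self.use_zbv = use_zbv


kwargs = {"gather_output": True, "use_zbv": True}
ColumnParallelLinearStub(**kwargs)  # accepted
try:
    GradAccumLinearStub(**kwargs)   # fails: unexpected keyword 'gather_output'
except TypeError as err:
    print(err)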

examples/language/mixtral/benchmark.py (2 changed lines)

@@ -122,7 +122,7 @@ def main():
                num_ckpt_layers_per_stage=[19, 19, 19, 13],
            ),
            "num_layers_per_stage": [19, 20, 20, 21],
-           # "pp_style": "interleaved",
+           "pp_style": "interleaved",
        }
        if args.custom_ckpt
        else {}
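
The benchmark change simply re-enables the previously commented-out pp_style entry, so the custom-checkpoint branch of the config now also requests interleaved pipeline scheduling. A self-contained sketch of how such a kwargs dict is typically built and splatted into a plugin constructor (DummyPlugin and its arguments are illustrative, not the real benchmark code):

# Illustrative stand-in for the booster plugin the benchmark actually uses.
class DummyPlugin:
    def __init__(self, pp_size: int, **kwargs):
        self.pp_size = pp_size
        self.pp_style = kwargs.get("pp_style", "1f1b")
        self.num_layers_per_stage = kwargs.get("num_layers_per_stage")


custom_ckpt = True  # stands in for args.custom_ckpt
custom_kwargs = (
    {
        "num_layers_per_stage": [19, 20, 20, 21],
        "pp_style": "interleaved",  # the entry re-enabled by this commit
    }
    if custom_ckpt
    else {}
)

plugin = DummyPlugin(pp_size=4, **custom_kwargs)
print(plugin.pp_style)  # "interleaved" when custom_ckpt is on, default otherwise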
