mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid;
parent 014afbdb59
commit 5c2ebbfd48
@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
-        overlap_p2p: bool = False,
+        overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
         # Not support overlap_p2p so far
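Flipping the overlap_p2p default to True means the scheduler now posts non-blocking p2p operations and defers the wait until the received data is actually consumed. A minimal sketch of that idea in plain torch.distributed, with a hypothetical helper name and an assumed initialized process group; this is not ColossalAI's actual communication layer:

    import torch
    import torch.distributed as dist

    def send_forward_recv_backward(send_tensor: torch.Tensor, recv_buf: torch.Tensor, peer: int):
        # Post a non-blocking send/recv pair and return the work handles
        # instead of waiting here; the caller waits just before it needs
        # recv_buf, so compute can overlap the transfer in flight.
        ops = [
            dist.P2POp(dist.isend, send_tensor, peer),
            dist.P2POp(dist.irecv, recv_buf, peer),
        ]
        return dist.batch_isend_irecv(ops)  # list of waitable work objects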
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
+                self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     outputs=outputs,
                 )
             elif scheduled_node.type == "B":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_b(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
            outputs = merge_batch(outputs)
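The recurring drain loop is the "update wait handles" part of the commit: each communication node appends one batch of handles, every compute node (F or B) first waits out everything pending so its inputs are guaranteed to have arrived, and a final drain after the schedule loop catches stragglers. A hypothetical standalone helper capturing the same pattern (not a method of the scheduler); re-waiting an already-completed work object returns immediately, which is why draining the same list more than once is harmless:

    def drain(wait_handles):
        # wait_handles is a list of lists: one inner list of work objects
        # per communication node, as produced by e.g. batch_isend_irecv.
        for batch in wait_handles:
            for handle in batch:
                handle.wait()  # no-op if this transfer already finished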
@@ -381,7 +381,6 @@ class MixtralPipelineForwards:
                     output_router_logits,
                     use_cache,
                 )

                 hidden_states = layer_outputs[0]

                 if use_cache:
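For orientation, this hunk sits in the stage-local decoder loop of the Mixtral pipeline forward, where each HuggingFace decoder layer returns a tuple whose first element is the new hidden states — hence `hidden_states = layer_outputs[0]`. A simplified, self-contained sketch of that loop shape (the signature and layer call here are illustrative, not the file's exact code, which also threads position ids, router logits, and the KV cache):

    from typing import Optional

    import torch
    import torch.nn as nn

    def run_stage_layers(
        layers: nn.ModuleList,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ):
        next_decoder_cache = None
        for layer in layers:
            # HF decoder layers return a tuple: new hidden states first,
            # then optional attention weights and cache entries.
            layer_outputs = layer(hidden_states, attention_mask=attention_mask, use_cache=use_cache)
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[-1]  # cache last, for simplicity here
        return hidden_states, next_decoder_cache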
@@ -214,7 +214,6 @@ class MixtralPolicy(Policy):
                     suffix="block_sparse_moe.gate",
                     target_module=LinearWithGradAccum,
                     kwargs={
-                        "gather_output": True,
                         "fp8_communication": self.shard_config.fp8_communication,
                         "use_zbv": use_zbv,
                     },
@@ -414,7 +413,6 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     suffix="lm_head",
                     target_module=LinearWithGradAccum,
                     kwargs=dict(
-                        gather_output=True,
                         fp8_communication=self.shard_config.fp8_communication,
                         use_zbv=use_zbv,
                     ),
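Both policy hunks drop the gather_output kwarg from submodules retargeted to LinearWithGradAccum: gather_output is a column-parallel (Linear1D_Col) option, and LinearWithGradAccum keeps the full, unsharded weight, so there is nothing to gather. A sketch of the corrected descriptor; SubModuleReplacementDescription is the shardformer policies' real building block, while the LinearWithGradAccum import path and the function wrapper are assumptions for illustration:

    from colossalai.shardformer.layer import LinearWithGradAccum  # path assumed
    from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

    def lm_head_replacement(shard_config, use_zbv: bool) -> SubModuleReplacementDescription:
        # Matches the corrected hunk: no gather_output, since the
        # replacement module holds the unsharded weight.
        return SubModuleReplacementDescription(
            suffix="lm_head",
            target_module=LinearWithGradAccum,
            kwargs=dict(
                fp8_communication=shard_config.fp8_communication,
                use_zbv=use_zbv,
            ),
        )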
@@ -122,7 +122,7 @@ def main():
                 num_ckpt_layers_per_stage=[19, 19, 19, 13],
             ),
             "num_layers_per_stage": [19, 20, 20, 21],
-            # "pp_style": "interleaved",
+            "pp_style": "interleaved",
         }
         if args.custom_ckpt
         else {}
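The benchmark hunk uncomments pp_style, so the custom-checkpoint run now uses interleaved pipeline scheduling; the whole dict only applies when --custom_ckpt is passed. A sketch of how such conditional kwargs typically feed HybridParallelPlugin in the llama benchmark; everything beyond the values visible in the hunk (tp_size, pp_size, num_model_chunks, the argparse stub, and the PipelineGradientCheckpointConfig import) is an assumption here, though pp_size=4 is suggested by the four-entry per-stage lists:

    import argparse

    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.shardformer import PipelineGradientCheckpointConfig

    parser = argparse.ArgumentParser()
    parser.add_argument("--custom_ckpt", action="store_true")
    args = parser.parse_args()

    hybrid_kwargs = (
        {
            # Checkpoint 19/19/19/13 layers on the four pipeline stages.
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_ckpt_layers_per_stage=[19, 19, 19, 13],
            ),
            # Uneven layer split across the four stages.
            "num_layers_per_stage": [19, 20, 20, 21],
            "pp_style": "interleaved",
        }
        if args.custom_ckpt
        else {}
    )

    # Assumes an initialized distributed context (e.g. colossalai.launch_from_torch).
    plugin = HybridParallelPlugin(tp_size=1, pp_size=4, num_model_chunks=2, **hybrid_kwargs)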