Browse Source

[fix] fix zbv llama pp4

pull/6114/head
duanjunwen 2 weeks ago
parent
commit
0d6d40ccc6
  1. 33
      colossalai/pipeline/schedule/zero_bubble_pp.py
  2. 25
      tests/test_pipeline/test_schedule/test_zerobubble_pp.py
  3. 2
      tests/test_shardformer/test_model/test_shard_llama.py

33
colossalai/pipeline/schedule/zero_bubble_pp.py

@ -226,7 +226,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; cause u are chunk 0 in first rank, u have no prev rank; # do nothing; cause u are chunk 0 in first rank, u have no prev rank;
################# #################
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
# return None, []
return [] return []
################ ################
@ -241,7 +240,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor)
# return input_tensor, wait_handles
return wait_handles return wait_handles
else: else:
@ -265,7 +263,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append(input_tensor) self.recv_forward_buffer[model_chunk_id].append(input_tensor)
# return input_tensor, wait_handles
return wait_handles return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
@ -313,7 +310,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# do nothing; get loss from local # do nothing; get loss from local
################ ################
if self.stage_manager.is_first_stage(ignore_chunk=True): if self.stage_manager.is_first_stage(ignore_chunk=True):
# return None, []
return [] return []
################ ################
@ -328,7 +324,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
# return output_tensor_grad, wait_handles
return wait_handles return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
@ -665,7 +660,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss, accum_loss=accum_loss,
outputs=outputs, outputs=outputs,
) )
# print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};")
# Step3: # Step3:
# 3-1:detach output; detach output for send fwd; # 3-1:detach output; detach output for send fwd;
@ -748,20 +742,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.input_tensors[model_chunk_id].pop(0) input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0)
# # save output_tensor_grad for dw
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# # we save loss here
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
# else:
# # we save output_tensor_grad here
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# the_output_obj_grad = []
# if isinstance(output_obj, dict):
# for (k, v) in output_obj.items():
# the_output_obj_grad.append(v.requires_grad)
# else:
# the_output_obj_grad.append(output_obj.requires_grad)
input_object_grad = self.backward_b_step( input_object_grad = self.backward_b_step(
model_chunk=model_chunk, model_chunk=model_chunk,
model_chunk_id=model_chunk_id, model_chunk_id=model_chunk_id,
@ -804,20 +784,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
Returns: Returns:
Nothing. Nothing.
""" """
# get y & dy from buffer
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
WeightGradStore.pop(chunk=model_chunk_id) WeightGradStore.pop(chunk=model_chunk_id)
# self.backward_w_step(
# model_chunk=model_chunk,
# model_chunk_id=model_chunk_id,
# optimizer=optimizer,
# output_obj=output_obj,
# output_obj_grad=output_obj_grad,
# )
def run_forward_only( def run_forward_only(
self, self,
model_chunk: Union[ModuleList, Module], model_chunk: Union[ModuleList, Module],
@ -890,7 +858,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)): for it in range(len(schedule)):
scheduled_node = schedule[it] scheduled_node = schedule[it]
# print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication # communication
communication_func = self.communication_map[scheduled_node.type] communication_func = self.communication_map[scheduled_node.type]

25
tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@ -749,12 +749,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups) assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
# TODO:3) support booster & Hybrid base 2)
def run_with_hybridplugin(test_config):
pass
# TODO:4) support booster & MoEHybrid base 2)
@parameterize( @parameterize(
"config", "config",
[ [
@ -923,9 +917,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
"config", "config",
[ [
# # Pass # # Pass
(1, 2, 2, 1), # (1, 2, 2, 1),
(1, 2, 1, 2), # (1, 2, 1, 2),
(1, 1, 2, 2), # (1, 1, 2, 2),
# TODO: acc err in pp4 # TODO: acc err in pp4
(1, 4, 1, 1), (1, 4, 1, 1),
], ],
@ -1071,6 +1065,17 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
torch_optimizer.step() torch_optimizer.step()
torch_optimizer.zero_grad() torch_optimizer.zero_grad()
# assert param
for parall_name, parall_param in parallel_model.named_parameters():
parall_name = ".".join(parall_name.split(".")[1:])
for base_name, base_param in torch_model.named_parameters():
if parall_name == base_name:
# assert weight
assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
# assert weight.grad
if parall_param.grad is not None:
assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
clear_layout_converter() clear_layout_converter()
@ -1081,7 +1086,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_with_booster_moehybridplugin() # run_with_booster_moehybridplugin()
run_with_booster_hybridplugin() run_with_booster_hybridplugin()

2
tests/test_shardformer/test_model/test_shard_llama.py

@ -420,4 +420,4 @@ def test_llama_3d():
if __name__ == "__main__": if __name__ == "__main__":
test_llama() test_llama()
test_llama_3d() # test_llama_3d()

Loading…
Cancel
Save