mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix zbv llama pp4
parent
4fc92aa77d
commit
0d6d40ccc6
|
@ -226,7 +226,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
|
||||||
#################
|
#################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# return None, []
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
|
@ -241,7 +240,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||||
# return input_tensor, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -265,7 +263,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
|
||||||
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
|
||||||
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
self.recv_forward_buffer[model_chunk_id].append(input_tensor)
|
||||||
# return input_tensor, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||||
|
@ -313,7 +310,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# do nothing; get loss from local
|
# do nothing; get loss from local
|
||||||
################
|
################
|
||||||
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
if self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
# return None, []
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
################
|
################
|
||||||
|
@ -328,7 +324,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
|
||||||
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
|
||||||
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad)
|
||||||
# return output_tensor_grad, wait_handles
|
|
||||||
return wait_handles
|
return wait_handles
|
||||||
|
|
||||||
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
|
||||||
|
@ -665,7 +660,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
accum_loss=accum_loss,
|
accum_loss=accum_loss,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
# print(f"stage {self.stage_manager.stage}; model_chunk_id {model_chunk_id}; output_obj {output_obj};")
|
|
||||||
|
|
||||||
# Step3:
|
# Step3:
|
||||||
# 3-1:detach output; detach output for send fwd;
|
# 3-1:detach output; detach output for send fwd;
|
||||||
|
@ -748,20 +742,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
input_obj = self.input_tensors[model_chunk_id].pop(0)
|
||||||
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
output_obj = self.output_tensors[model_chunk_id].pop(0)
|
||||||
|
|
||||||
# # save output_tensor_grad for dw
|
|
||||||
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
|
||||||
# # we save loss here
|
|
||||||
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
|
|
||||||
# else:
|
|
||||||
# # we save output_tensor_grad here
|
|
||||||
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
|
||||||
# the_output_obj_grad = []
|
|
||||||
# if isinstance(output_obj, dict):
|
|
||||||
# for (k, v) in output_obj.items():
|
|
||||||
# the_output_obj_grad.append(v.requires_grad)
|
|
||||||
# else:
|
|
||||||
# the_output_obj_grad.append(output_obj.requires_grad)
|
|
||||||
|
|
||||||
input_object_grad = self.backward_b_step(
|
input_object_grad = self.backward_b_step(
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
model_chunk_id=model_chunk_id,
|
||||||
|
@ -804,20 +784,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
Returns:
|
Returns:
|
||||||
Nothing.
|
Nothing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# get y & dy from buffer
|
|
||||||
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
|
||||||
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
|
||||||
WeightGradStore.pop(chunk=model_chunk_id)
|
WeightGradStore.pop(chunk=model_chunk_id)
|
||||||
|
|
||||||
# self.backward_w_step(
|
|
||||||
# model_chunk=model_chunk,
|
|
||||||
# model_chunk_id=model_chunk_id,
|
|
||||||
# optimizer=optimizer,
|
|
||||||
# output_obj=output_obj,
|
|
||||||
# output_obj_grad=output_obj_grad,
|
|
||||||
# )
|
|
||||||
|
|
||||||
def run_forward_only(
|
def run_forward_only(
|
||||||
self,
|
self,
|
||||||
model_chunk: Union[ModuleList, Module],
|
model_chunk: Union[ModuleList, Module],
|
||||||
|
@ -890,7 +858,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||||
for it in range(len(schedule)):
|
for it in range(len(schedule)):
|
||||||
scheduled_node = schedule[it]
|
scheduled_node = schedule[it]
|
||||||
# print(f"rank {torch.distributed.get_rank()}; stage {self.stage_manager.stage}; scheduled_node {scheduled_node};")
|
|
||||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||||
# communication
|
# communication
|
||||||
communication_func = self.communication_map[scheduled_node.type]
|
communication_func = self.communication_map[scheduled_node.type]
|
||||||
|
|
|
@ -749,12 +749,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
||||||
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
|
||||||
|
|
||||||
|
|
||||||
# TODO:3) support booster & Hybrid base 2)
|
|
||||||
def run_with_hybridplugin(test_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# TODO:4) support booster & MoEHybrid base 2)
|
|
||||||
@parameterize(
|
@parameterize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
|
@ -923,9 +917,9 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
# # Pass
|
# # Pass
|
||||||
(1, 2, 2, 1),
|
# (1, 2, 2, 1),
|
||||||
(1, 2, 1, 2),
|
# (1, 2, 1, 2),
|
||||||
(1, 1, 2, 2),
|
# (1, 1, 2, 2),
|
||||||
# TODO: acc err in pp4
|
# TODO: acc err in pp4
|
||||||
(1, 4, 1, 1),
|
(1, 4, 1, 1),
|
||||||
],
|
],
|
||||||
|
@ -1071,6 +1065,17 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
torch_optimizer.zero_grad()
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
|
# assert param
|
||||||
|
for parall_name, parall_param in parallel_model.named_parameters():
|
||||||
|
parall_name = ".".join(parall_name.split(".")[1:])
|
||||||
|
for base_name, base_param in torch_model.named_parameters():
|
||||||
|
if parall_name == base_name:
|
||||||
|
# assert weight
|
||||||
|
assert_loose_close(parall_param, base_param, dtype=dtype, name=parall_name)
|
||||||
|
# assert weight.grad
|
||||||
|
if parall_param.grad is not None:
|
||||||
|
assert_loose_close(parall_param.grad, base_param.grad, dtype=dtype, name=f"{parall_name}.grad")
|
||||||
|
|
||||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
|
@ -1081,7 +1086,7 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||||
run_with_booster_moehybridplugin()
|
# run_with_booster_moehybridplugin()
|
||||||
run_with_booster_hybridplugin()
|
run_with_booster_hybridplugin()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -420,4 +420,4 @@ def test_llama_3d():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_llama()
|
test_llama()
|
||||||
test_llama_3d()
|
# test_llama_3d()
|
||||||
|
|
Loading…
Reference in New Issue