[fix] fix bwd b; now bwd w runs only for layers replaced by Linear1D_Col/Row; other layers perform a full bwd;

pull/6083/head
duanjunwen 2024-10-15 06:26:01 +00:00
parent 160e9a4175
commit 9912cc8c07
4 changed files with 11 additions and 11 deletions

View File

@@ -509,12 +509,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         optimizer.backward_by_grad(
             tensor=output_obj_,
             grad=output_obj_grad_,
-            inputs=input_obj_,
-            retain_graph=True,
+            # inputs=input_obj_,
+            # retain_graph=True,
         )
         # Format output_obj_grad
-        input_obj_grad = {}
+        input_obj_grad = dict()
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
             pass
         else:
@@ -714,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # # we save output_tensor_grad here
         # self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
-        # Step2: bwd step
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
@@ -761,7 +759,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # get y & dy from buffer
         # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
         # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
         WeightGradStore.pop(chunk=model_chunk_id)
         # self.backward_w_step(
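
In these hunks, backward_b_step now performs only the activation-gradient (B) pass, while the weight-gradient (W) pass for the layers replaced by Linear1D_Col/Row is deferred and later flushed with WeightGradStore.pop(chunk=model_chunk_id). Below is a minimal, self-contained PyTorch sketch of that deferral pattern; DeferredLinear, _wgrad_queue and pop are hypothetical names used for illustration and are not ColossalAI's WeightGradStore implementation.

# Sketch only: B/W split for a plain linear layer (hypothetical names, not ColossalAI code).
import torch

_wgrad_queue = []  # deferred weight-gradient closures, one per deferred layer call


class DeferredLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_input = grad_output @ weight  # B step: the previous stage needs dX immediately
        x_d, g_d = x.detach(), grad_output.detach()

        def wgrad():  # W step: dW = dY^T X, executed later when the queue is flushed
            with torch.no_grad():
                dw = g_d.t() @ x_d
                weight.grad = dw if weight.grad is None else weight.grad + dw

        _wgrad_queue.append(wgrad)
        return grad_input, None


def pop():
    """Flush all deferred weight-gradient computations (the 'bwd w' step)."""
    while _wgrad_queue:
        _wgrad_queue.pop(0)()


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
DeferredLinear.apply(x, w).sum().backward()  # bwd b: x.grad is set, w.grad is still None
assert x.grad is not None and w.grad is None
pop()                                        # bwd w: now w.grad is populated
assert w.grad is not None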

View File

@@ -177,7 +177,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
            handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
            # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
            # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if use_zbv:
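
The context above launches the grad_input all-reduce asynchronously and computes the weight gradient while it is in flight, relying on CUDA_DEVICE_MAX_CONNECTIONS=1 so the all-reduce is scheduled first. A minimal sketch of that overlap pattern follows; the single-process gloo group exists only to make the example runnable, and this is not the LinearWithAsyncCommunication code itself.

# Sketch only: overlap an async all-reduce of dX with the dW computation.
import torch
import torch.distributed as dist

# single-process group purely so the sketch can run standalone
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

x = torch.randn(4, 8)             # layer input saved from forward
grad_output = torch.randn(4, 16)  # incoming gradient dY
weight = torch.randn(16, 8)

grad_input = grad_output @ weight                    # dX = dY @ W
handle = dist.all_reduce(grad_input, async_op=True)  # launch communication first
grad_weight = grad_output.t() @ x                    # dW = dY^T @ X overlaps with the all-reduce
handle.wait()                                        # grad_input is reduced across ranks here

dist.destroy_process_group()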

View File

@@ -230,7 +230,6 @@ class Linear1D_Col(ParallelModule):
                fp8_communication=self.fp8_communication,
                use_zbv=self.use_zbv,
            )
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(
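
For the gather_output branch shown in context: a column-parallel linear shards the weight along the output dimension, so each rank produces a slice of the output that is all-gathered back into the full activation. A single-process shape sketch under assumed toy sizes (not the Linear1D_Col code):

# Sketch only: column-parallel sharding and output gathering, simulated on one process.
import torch

tp_size, in_features, out_features = 2, 8, 16
x = torch.randn(4, in_features)
full_weight = torch.randn(out_features, in_features)

# Each "rank" owns out_features // tp_size rows of the weight (one output shard).
shards = full_weight.chunk(tp_size, dim=0)
partial_outputs = [x @ w_shard.t() for w_shard in shards]  # each (4, out_features // tp_size)

# gather_output=True: all-gather (here, concatenate) along the last dim recovers the full output.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, x @ full_weight.t(), atol=1e-5)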

View File

@@ -753,8 +753,10 @@ def run_with_hybridplugin(test_config):
    "config",
    [
        # TODO:ERR in second iter
-        # (0, 1, 4, 1, 1),
-        # (1, 2, 2, 1, 1),
+        (0, 1, 4, 1, 1),
+        (1, 2, 2, 1, 1),
+        (1, 1, 2, 2, 1),
+        # Pass
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
@@ -891,19 +893,22 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
    # ===================================================================================
    # run normal model with all dp(different) inputs
-    all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+    all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
    torch_output_sum = 0
    for input_data_ in all_inputs:
        torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
        torch_output.backward()
        torch_output_sum += torch_output.detach()
+    # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
    # avg dp grads follows zero optimizer
    for p in torch_model.parameters():
        if p.grad is not None:
            p.grad /= dp_size
    torch_optimizer.step()
    torch_optimizer.zero_grad()
+    # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
    print(f"rank {dist.get_rank()} config {test_config} test passed")
    clear_layout_converter()
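
The dp-grad averaging in this test mirrors what the ZeRO-style optimizer computes on each rank: summing the losses over every DP rank's input and then dividing the accumulated gradients by dp_size is equivalent to backpropagating the mean loss. A small standalone sketch of that equivalence, using a hypothetical toy model rather than the test's MoE model:

# Sketch only: "sum losses, then divide grads by dp_size" equals "backprop the mean loss".
import torch

dp_size = 4
model = torch.nn.Linear(8, 1)
inputs = [torch.randn(2, 8) for _ in range(dp_size)]

for x in inputs:                      # grads accumulate across the simulated dp "ranks"
    model(x).mean().backward()
for p in model.parameters():
    p.grad /= dp_size                 # avg dp grads, as in the test above

ref = torch.nn.Linear(8, 1)
ref.load_state_dict(model.state_dict())
torch.stack([ref(x).mean() for x in inputs]).mean().backward()  # mean loss in one shot

for p, q in zip(model.parameters(), ref.parameters()):
    assert torch.allclose(p.grad, q.grad, atol=1e-6)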