
[fix] Fix bwd b; bwd w now runs only for layers replaced by Linear1D_Col/Row, while other layers perform a full backward.
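For context, the zero-bubble schedule splits the backward pass into a B step (propagate activation gradients) and a W step (weight gradients computed later via WeightGradStore). Below is a minimal sketch of that split, assuming a plain nn.Linear and a hypothetical queue standing in for WeightGradStore; it is illustrative, not the ColossalAI implementation.

```python
# Minimal sketch of the B/W backward split (not the ColossalAI implementation):
# the B pass always produces dL/dx, and only "zbv-aware" layers (standing in
# for Linear1D_Col/Row) defer dL/dW to a later W pass via a queue that plays
# the role of WeightGradStore.
import torch
import torch.nn as nn

deferred = []  # (input, grad_output, weight) tuples, analogous to WeightGradStore

def backward_b(layer: nn.Linear, x: torch.Tensor, grad_out: torch.Tensor, defer_dw: bool):
    """B step: propagate the activation gradient; compute dW now unless deferred."""
    grad_input = grad_out @ layer.weight            # dL/dx for y = x @ W.T + b
    if defer_dw:
        deferred.append((x.detach(), grad_out.detach(), layer.weight))
    else:
        layer.weight.grad = grad_out.t() @ x        # full backward for other layers
    return grad_input

def backward_w():
    """W step: drain the deferred weight-gradient work (WeightGradStore.pop analogue)."""
    while deferred:
        x, grad_out, weight = deferred.pop(0)
        weight.grad = grad_out.t() @ x
```

Per the commit message, only Linear1D_Col/Row-style layers take the deferred path; every other layer falls through to the full backward inside the B step.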

pull/6083/head
duanjunwen (1 month ago)
commit 9912cc8c07
  1. colossalai/pipeline/schedule/zero_bubble_pp.py (9)
  2. colossalai/shardformer/layer/_operation.py (1)
  3. colossalai/shardformer/layer/linear.py (1)
  4. tests/test_pipeline/test_schedule/test_zerobubble_pp.py (11)

colossalai/pipeline/schedule/zero_bubble_pp.py (9)

@@ -509,12 +509,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
-inputs=input_obj_,
-retain_graph=True,
+# inputs=input_obj_,
+# retain_graph=True,
)
# Format output_obj_grad
-input_obj_grad = {}
+input_obj_grad = dict()
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
@@ -714,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# # we save output_tensor_grad here
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# Step2: bwd step
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
@@ -761,7 +759,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# get y & dy from buffer
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
WeightGradStore.pop(chunk=model_chunk_id)
# self.backward_w_step(

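The first hunk drops inputs=input_obj_ and retain_graph=True from backward_by_grad, so the B step now runs a full backward for layers that do not defer their weight gradient. A small sketch of the difference, using plain torch.autograd.backward (backward_by_grad is assumed here to forward these arguments to a similar call; the tensors are illustrative).

```python
# Sketch of the behaviour change behind dropping inputs=/retain_graph= from
# backward_by_grad (assumed to forward these arguments to torch.autograd.backward).
import torch

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(4, 3, requires_grad=True)
y = x @ w
dy = torch.ones_like(y)

# Restricted backward: only x accumulates a gradient, w.grad stays None,
# so the weight gradient must come from a separate W step.
torch.autograd.backward(y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# Full backward (the B step after this commit, for layers that do not defer dW):
# both x.grad and w.grad are populated in one pass.
x.grad = None
torch.autograd.backward(y, grad_tensors=dy)
assert x.grad is not None and w.grad is not None
```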
colossalai/shardformer/layer/_operation.py (1)

@@ -177,7 +177,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:

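The context around this hunk launches the input-gradient all-reduce asynchronously and relies on CUDA_DEVICE_MAX_CONNECTIONS=1 so the communication kernel is issued before the weight-gradient matmul, letting compute overlap communication. A rough sketch of that pattern follows; names and shapes are illustrative, not the library code.

```python
# Overlap pattern: kick off an async all-reduce of the input gradient, do the
# (possibly deferred) weight-gradient matmul while communication runs, then wait.
import torch
import torch.distributed as dist

def linear_backward_with_overlap(grad_output, x, weight, process_group, defer_dw=False):
    grad_input = grad_output @ weight                        # dL/dx
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    # With CUDA_DEVICE_MAX_CONNECTIONS=1 the all-reduce is issued before the
    # matmul below, so the two overlap on the GPU.
    grad_weight = None if defer_dw else grad_output.t() @ x  # dL/dW, unless deferred to the W pass
    handle.wait()
    return grad_input, grad_weight
```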
colossalai/shardformer/layer/linear.py (1)

@@ -230,7 +230,6 @@ class Linear1D_Col(ParallelModule):
fp8_communication=self.fp8_communication,
use_zbv=self.use_zbv,
)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(

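For context on the touched Linear1D_Col.forward: the weight is sharded along the output dimension, so with gather_output=True the per-rank partial results are all-gathered and concatenated (the matching backward splits the gradient again). A simplified, forward-only sketch under those assumptions:

```python
# Column-parallel linear, forward only: each rank owns a slice of the output
# features; gather_output concatenates the per-rank partial results.
import torch
import torch.distributed as dist

def column_parallel_linear_forward(x, weight_shard, bias_shard, process_group, gather_output):
    # weight_shard: (out_features // world_size, in_features) on this rank
    partial = torch.nn.functional.linear(x, weight_shard, bias_shard)
    if not gather_output:
        return partial
    world_size = dist.get_world_size(process_group)
    pieces = [torch.empty_like(partial) for _ in range(world_size)]
    dist.all_gather(pieces, partial, group=process_group)
    return torch.cat(pieces, dim=-1)  # full (..., out_features) output
```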
tests/test_pipeline/test_schedule/test_zerobubble_pp.py (11)

@@ -753,8 +753,10 @@ def run_with_hybridplugin(test_config):
"config",
[
# TODO:ERR in second iter
-# (0, 1, 4, 1, 1),
-# (1, 2, 2, 1, 1),
+(0, 1, 4, 1, 1),
+(1, 2, 2, 1, 1),
+(1, 1, 2, 2, 1),
+# Pass
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
],
@@ -891,19 +893,22 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
# ===================================================================================
# run normal model with all dp(different) inputs
-all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
-# print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
# avg dp grads follows zero optimizer
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
+# print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
print(f"rank {dist.get_rank()} config {test_config} test passed")
clear_layout_converter()

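The test hunk builds a data-parallel reference: gather every rank's input, run the unsharded torch model over all of them, scale the accumulated gradients by 1/dp_size to mirror the ZeRO optimizer's averaging, then compare against the parallel output with assert_loose_close. A condensed sketch of that reference loop, with names taken from the test and the surrounding setup omitted:

```python
# Data-parallel reference check used by the test: run the unsharded model on
# every DP rank's input, average grads by dp_size (ZeRO-style), return the
# summed loss for comparison with the parallel schedule's output.
import torch
import torch.distributed as dist

def reference_forward_backward(torch_model, input_embeddings, dp_group, dtype):
    dp_size = dist.get_world_size(dp_group)
    all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
    dist.all_gather(all_inputs, input_embeddings, group=dp_group)

    torch_output_sum = 0
    for input_data_ in all_inputs:
        torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
        torch_output.backward()
        torch_output_sum += torch_output.detach()

    # average DP grads the way the ZeRO optimizer does
    for p in torch_model.parameters():
        if p.grad is not None:
            p.grad /= dp_size
    return torch_output_sum
```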