mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix bwd b; bwd w now runs only for layers replaced by Linear1D_Col/Row; all other layers perform a full bwd;
parent
160e9a4175
commit
9912cc8c07
|
@ -509,12 +509,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
optimizer.backward_by_grad(
|
optimizer.backward_by_grad(
|
||||||
tensor=output_obj_,
|
tensor=output_obj_,
|
||||||
grad=output_obj_grad_,
|
grad=output_obj_grad_,
|
||||||
inputs=input_obj_,
|
# inputs=input_obj_,
|
||||||
retain_graph=True,
|
# retain_graph=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format output_obj_grad
|
# Format output_obj_grad
|
||||||
input_obj_grad = {}
|
input_obj_grad = dict()
|
||||||
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
@ -714,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# # we save output_tensor_grad here
|
# # we save output_tensor_grad here
|
||||||
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
|
||||||
|
|
||||||
# Step2: bwd step
|
|
||||||
input_object_grad = self.backward_b_step(
|
input_object_grad = self.backward_b_step(
|
||||||
model_chunk=model_chunk,
|
model_chunk=model_chunk,
|
||||||
model_chunk_id=model_chunk_id,
|
model_chunk_id=model_chunk_id,
|
||||||
|
@ -761,7 +759,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||||
# get y & dy from buffer
|
# get y & dy from buffer
|
||||||
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
|
||||||
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
|
||||||
|
|
||||||
WeightGradStore.pop(chunk=model_chunk_id)
|
WeightGradStore.pop(chunk=model_chunk_id)
|
||||||
|
|
||||||
# self.backward_w_step(
|
# self.backward_w_step(
|
||||||
|
|
|
@ -177,7 +177,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||||
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
|
||||||
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
|
||||||
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
|
||||||
|
|
||||||
if _grad_accum_fusion_available and weight.grad is not None:
|
if _grad_accum_fusion_available and weight.grad is not None:
|
||||||
grad = weight.grad
|
grad = weight.grad
|
||||||
if use_zbv:
|
if use_zbv:
|
||||||
|
|
|
@ -230,7 +230,6 @@ class Linear1D_Col(ParallelModule):
|
||||||
fp8_communication=self.fp8_communication,
|
fp8_communication=self.fp8_communication,
|
||||||
use_zbv=self.use_zbv,
|
use_zbv=self.use_zbv,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.gather_output:
|
if self.gather_output:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = gather_forward_split_backward(
|
output = gather_forward_split_backward(
|
||||||
|
|
|
@ -753,8 +753,10 @@ def run_with_hybridplugin(test_config):
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
# TODO:ERR in second iter
|
# TODO:ERR in second iter
|
||||||
# (0, 1, 4, 1, 1),
|
(0, 1, 4, 1, 1),
|
||||||
# (1, 2, 2, 1, 1),
|
(1, 2, 2, 1, 1),
|
||||||
|
(1, 1, 2, 2, 1),
|
||||||
|
# Pass
|
||||||
(1, 2, 1, 2, 1),
|
(1, 2, 1, 2, 1),
|
||||||
(1, 2, 1, 1, 2),
|
(1, 2, 1, 1, 2),
|
||||||
],
|
],
|
||||||
|
@ -891,19 +893,22 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
|
|
||||||
# ===================================================================================
|
# ===================================================================================
|
||||||
# run normal model with all dp(different) inputs
|
# run normal model with all dp(different) inputs
|
||||||
all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
|
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
||||||
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
||||||
torch_output_sum = 0
|
torch_output_sum = 0
|
||||||
for input_data_ in all_inputs:
|
for input_data_ in all_inputs:
|
||||||
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||||
torch_output.backward()
|
torch_output.backward()
|
||||||
torch_output_sum += torch_output.detach()
|
torch_output_sum += torch_output.detach()
|
||||||
|
# print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
|
||||||
# avg dp grads follows zero optimizer
|
# avg dp grads follows zero optimizer
|
||||||
for p in torch_model.parameters():
|
for p in torch_model.parameters():
|
||||||
if p.grad is not None:
|
if p.grad is not None:
|
||||||
p.grad /= dp_size
|
p.grad /= dp_size
|
||||||
torch_optimizer.step()
|
torch_optimizer.step()
|
||||||
torch_optimizer.zero_grad()
|
torch_optimizer.zero_grad()
|
||||||
|
|
||||||
|
# print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
|
||||||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||||
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
print(f"rank {dist.get_rank()} config {test_config} test passed")
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
|
|
Loading…
Reference in New Issue