[fix] fix optim bwd;

pull/6034/head
duanjunwen 3 months ago
parent 77fe44286c
commit 591a13bf7e

@@ -55,10 +55,10 @@ class OptimizerWrapper:
"""
loss.backward(*args, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensor, grad)
# def backward_by_grad(self, tensor: Tensor, grad: Tensor):
# torch.autograd.backward(tensor, grad)
def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor, retain_graph: bool = False):
"""
Performs a backward pass for dx or dw,
for dx, we only calculate dx = w*dy here
@@ -72,12 +72,32 @@ class OptimizerWrapper:
retain_graph (bool): default to False; we retain the graph in backward_b
"""
torch.autograd.backward(
tensors=tensors,
grad_tensors=grad_tensors,
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)
# def backward_b_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
# """
# Performs a backward pass for dx or dw,
# for dx, we only calculate dx = w*dy here
# for dw, we only calculate dw = x*dy here
# Args:
# tensor (Tensor): y or loss of current chunk;
# grad_tensors (Tensor): dy of current chunk;
# input_obj (Tensor): for dx, input_obj is x of current chunk;
# for dw, input_obj is w of current chunk;
# retain_graph (bool): default to be True, we retain graph in backward_b
# """
# torch.autograd.backward(
# tensors=tensors,
# grad_tensors=grad_tensors,
# inputs=inputs,
# retain_graph=retain_graph,
# )
def state_dict(self):
"""
Returns the optimizer state.
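A minimal sketch of what the unified backward_by_grad now does (editor's illustration, not part of the diff): passing inputs= the activation makes torch.autograd.backward compute only dx while the graph is retained, and passing inputs= the parameters computes only dw, after which the graph can be freed. The single nn.Linear below is an assumed stand-in for one model chunk.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)                      # stand-in for one model chunk
x = torch.randn(2, 4, requires_grad=True)    # stand-in for input_obj
y = layer(x)                                 # stand-in for output_obj
dy = torch.ones_like(y)                      # stand-in for output_obj_grad

# B step: dx = w * dy only; parameters receive no grad yet, graph is retained
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=[x], retain_graph=True)
assert x.grad is not None and layer.weight.grad is None

# W step: dw = x * dy only; graph can be released afterwards
torch.autograd.backward(
    tensors=y, grad_tensors=dy, inputs=list(layer.parameters()), retain_graph=False
)
assert layer.weight.grad is not None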

@@ -441,27 +441,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if model_chunk_id == 0:
# bwd step
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=None,
optimizer.backward_by_grad(
tensor=output_obj,
grad=None,
inputs=input_obj,
retain_graph=True,
)
else:
# common bwd step
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=input_obj,
retain_graph=True,
)
@@ -490,25 +490,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# calculate bwd w step ; only dw = x*dy;
if model_chunk_id == 0:
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
else:
if self.stage_manager.is_first_stage(ignore_chunk=True):
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=None,
optimizer.backward_by_grad(
tensor=output_obj,
grad=None,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
else:
optimizer.backward_b_w_by_grad(
tensors=output_obj,
grad_tensors=output_obj_grad,
optimizer.backward_by_grad(
tensor=output_obj,
grad=output_obj_grad,
inputs=list(model_chunk[model_chunk_id].parameters()),
retain_graph=False,
)
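The call sites above follow one pattern; here is a hedged sketch of the two steps as the scheduler appears to use them. The helper names are illustrative only; output_obj, output_obj_grad, input_obj and model_chunk mirror the diff, and the optimizer is assumed to expose the new backward_by_grad signature.

def schedule_backward_b(optimizer, output_obj, output_obj_grad, input_obj, is_loss_stage):
    # dx only; keep the graph so the deferred dw pass can still traverse it
    optimizer.backward_by_grad(
        tensor=output_obj,
        grad=None if is_loss_stage else output_obj_grad,  # at the loss stage, output_obj is the loss
        inputs=input_obj,
        retain_graph=True,
    )

def schedule_backward_w(optimizer, output_obj, output_obj_grad, model_chunk, is_loss_stage):
    # dw only; the graph is released once the weight grads are accumulated
    optimizer.backward_by_grad(
        tensor=output_obj,
        grad=None if is_loss_stage else output_obj_grad,
        inputs=list(model_chunk.parameters()),
        retain_graph=False,
    )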

@@ -14,16 +14,9 @@ from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
class MlpModel(nn.Module):
@@ -437,7 +430,7 @@ def run_fwd_bwd_iter_input(test_config):
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.Sequential().to(rank)
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
local_chunk.append(sub_model)
@@ -594,7 +587,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.Sequential().to(rank)
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
local_chunk.append(sub_model)
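On the Sequential to ModuleList switch above: a ModuleList only registers the rank-local layers (so .to(rank) and .parameters() still work) without Sequential's implicit forward chaining, which suits a pipeline schedule that invokes the chunk's layers itself. A small illustration under that assumption:

import torch.nn as nn

local_chunk = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)])  # e.g. the two layers owned by this rank

def forward_chunk(x, chunk):
    # the pipeline schedule, not the container, decides how the layers are applied
    for layer in chunk:
        x = layer(x)
    return x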
@@ -718,44 +711,46 @@ def run_with_moehybridplugin(test_config):
clear_layout_converter()
torch.set_default_dtype(torch.bfloat16)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name in model_list:
(
org_model,
org_optimizer,
sharded_model,
sharded_optimizer,
criterion,
booster,
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
org_optimizer.step()
sharded_optimizer.step()
# check weights
if test_config["precision"] == "bf16":
atol, rtol = 5e-4, 5e-4
else:
atol, rtol = 5e-4, 5e-4
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check optim states
# check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
print(f"Bert Model Zoo Test Passed")
data_gen_fn()
# print(f"data {data}")
# if name in model_list:
# (
# org_model,
# org_optimizer,
# sharded_model,
# sharded_optimizer,
# criterion,
# booster,
# ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, torch.optim.SGD, torch.optim.SGD)
# org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
# org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
# )
# stage_manager = booster.plugin.stage_manager
# tp_group = booster.plugin.tp_group
# bert = unwrap_model(org_model, "BertModel", "bert")
# sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
# weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
# org_optimizer.step()
# sharded_optimizer.step()
# # check weights
# if test_config["precision"] == "bf16":
# atol, rtol = 5e-4, 5e-4
# else:
# atol, rtol = 5e-4, 5e-4
# if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
# check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# # check optim states
# # check_dist_optim_state(org_optimizer, sharded_optimizer.optim)
# clear_layout_converter()
# Randomizer.reset_index()
# torch.cuda.empty_cache()
# print(f"Bert Model Zoo Test Passed")
# TODO:6) support booster & Hybrid base 4)
@@ -766,8 +761,9 @@ def run_with_moehybridplugin(test_config):
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_fwd_bwd_iter_input()
# run_fwd_bwd_iter_input()
run_fwd_bwd_vschedule_with_optim()
# run_with_moehybridplugin()
@pytest.mark.dist
