mirror of https://github.com/hpcaitech/ColossalAI
[feat] add optim backward_b_by_grad
parent
b1419ef76a
commit
4c4b01b859
|
@ -58,6 +58,28 @@ class OptimizerWrapper:
|
|||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
torch.autograd.backward(tensor, grad)
|
||||
|
||||
def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
|
||||
"""
|
||||
Performs a backward pass for dx, we only calculate dx = w*dy here
|
||||
|
||||
Args:
|
||||
tensor (Tensor): y or loss of current chunk;
|
||||
grad_tensors (Tensor): dy of current chunk;
|
||||
input_obj (Tensor): x of current chunk;
|
||||
retain_graph (bool): default to be True, we retain graph in backward_b
|
||||
"""
|
||||
torch.autograd.backward(
|
||||
tensors=tensor,
|
||||
grad_tensors=grad_tensors,
|
||||
inputs=inputs,
|
||||
retain_graph=retain_graph,
|
||||
)
|
||||
|
||||
def backward_w_by_grad():
|
||||
"""
|
||||
Performs a backward pass for dw, we only calculate dw = x*dy here
|
||||
"""
|
||||
|
||||
def state_dict(self):
|
||||
"""
|
||||
Returns the optimizer state.
|
||||
|
|
|
@ -413,7 +413,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
self,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
optimizer: OptimizerWrapper,
|
||||
input_obj: Optional[dict],
|
||||
output_obj: Union[dict, torch.Tensor],
|
||||
output_obj_grad: Optional[dict],
|
||||
|
@ -447,7 +447,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
|
||||
else:
|
||||
# commom bwd step
|
||||
# BUG:output_obj_grad is None
|
||||
torch.autograd.backward(
|
||||
tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
|
||||
)
|
||||
|
@ -564,7 +563,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
scheduled_node,
|
||||
model_chunk: Union[ModuleList, Module],
|
||||
model_chunk_id: int,
|
||||
# optimizer: OptimizerWrapper,
|
||||
optimizer: OptimizerWrapper,
|
||||
# input_obj: Optional[dict],
|
||||
# output_obj: Union[dict, torch.Tensor],
|
||||
# output_obj_grad: Optional[dict],
|
||||
|
@ -614,7 +613,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
input_object_grad = self.backward_b_step(
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=model_chunk_id,
|
||||
# optimizer: OptimizerWrapper,
|
||||
optimizer=optimizer,
|
||||
input_obj=input_obj,
|
||||
output_obj=output_obj,
|
||||
output_obj_grad=output_tensor_grad,
|
||||
|
@ -715,6 +714,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
scheduled_node=scheduled_node,
|
||||
model_chunk=model_chunk,
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
elif scheduled_node.type == "W":
|
||||
self.schedule_w(
|
||||
|
|
|
@ -9,6 +9,7 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
|
||||
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
@ -625,7 +626,148 @@ def run_fwd_bwd_vschedule_with_optim(
|
|||
batch_size: int,
|
||||
num_model_chunk: int,
|
||||
):
|
||||
pass
|
||||
# init dist
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
rank = dist.get_rank()
|
||||
pp_size = world_size
|
||||
pg_mesh = ProcessGroupMesh(pp_size)
|
||||
num_microbatch = num_microbatch
|
||||
# stage_manager
|
||||
stage_manager = PipelineStageManager(
|
||||
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
||||
)
|
||||
|
||||
h, a, s = 4096, 32, 1024
|
||||
mem_f = 34 * h + 5 * a * s
|
||||
mem_w = -32 * h
|
||||
mem_b = -mem_w - mem_f
|
||||
graph = PipelineGraph(
|
||||
n_stage=world_size,
|
||||
n_micro=num_microbatch,
|
||||
f_cost=6,
|
||||
b_cost=6,
|
||||
w_cost=6,
|
||||
c_cost=6,
|
||||
f_mem=mem_f,
|
||||
b_mem=mem_b,
|
||||
w_mem=mem_w,
|
||||
# max_mem=mem_f * (p * 2 + m_offset),
|
||||
)
|
||||
|
||||
zbv_schedule = graph.get_v_schedule()
|
||||
|
||||
scheduler = ZeroBubbleVPipeScheduler(
|
||||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
||||
stage_manager=stage_manager,
|
||||
num_model_chunks=num_model_chunk,
|
||||
num_microbatch=num_microbatch,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
||||
# init loss func
|
||||
def criterion(x, *args, **kwargs):
|
||||
return (x * x).mean()
|
||||
|
||||
# init model and input
|
||||
batch_size = batch_size
|
||||
num_layers = 8
|
||||
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
||||
in_dim = out_dim = 8
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
|
||||
|
||||
input_base = [t.clone() for t in data_iter]
|
||||
model_base = deepcopy(model)
|
||||
|
||||
if rank == 0:
|
||||
# layer 0 & 7 to chunk 0 on rank0
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 0 or idx == 7:
|
||||
local_chunk.append(sub_model)
|
||||
elif rank == 1:
|
||||
# layer 1 & 6 to chunk 1 on rank1
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 1 or idx == 6:
|
||||
local_chunk.append(sub_model)
|
||||
elif rank == 2:
|
||||
# layer 2 & 5 to chunk 2 on rank2
|
||||
local_chunk = torch.nn.ModuleList().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 2 or idx == 5:
|
||||
local_chunk.append(sub_model)
|
||||
else:
|
||||
# layer 3 & 4 to chunk 3 on rank3
|
||||
local_chunk = torch.nn.Sequential().to(rank)
|
||||
for idx, sub_model in enumerate(model.layers):
|
||||
if idx == 3 or idx == 4:
|
||||
local_chunk.append(sub_model)
|
||||
|
||||
# init optimizer
|
||||
optimizer_base = torch.optim.SGD(model_base.parameters(), lr=1e-5)
|
||||
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), lr=1e-5))
|
||||
|
||||
print(
|
||||
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
scheduler.run_forward_backward(
|
||||
model_chunk=local_chunk,
|
||||
data_iter=iter(data_iter),
|
||||
criterion=criterion,
|
||||
optimizer=optimizer_pp,
|
||||
return_loss=None,
|
||||
return_outputs=None,
|
||||
)
|
||||
|
||||
##########################
|
||||
# Fwd bwd for base
|
||||
##########################
|
||||
# fwd & bwd
|
||||
output_base = model_base(input_base[0])
|
||||
loss_base = criterion(output_base)
|
||||
loss_base.backward()
|
||||
optimizer_base.step()
|
||||
print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
|
||||
|
||||
##########################
|
||||
# assert weight
|
||||
##########################
|
||||
if rank == 0:
|
||||
# layer 0
|
||||
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
|
||||
# layer 7
|
||||
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
|
||||
if rank == 1:
|
||||
# layer 1
|
||||
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
|
||||
# layer 6
|
||||
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
|
||||
if rank == 2:
|
||||
# layer 2
|
||||
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
|
||||
# layer 5
|
||||
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
|
||||
if rank == 3:
|
||||
# layer 3
|
||||
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
|
||||
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
|
||||
# layer 4
|
||||
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
|
||||
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
|
||||
|
||||
##########################
|
||||
# assert optim state
|
||||
##########################
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
@ -634,8 +776,16 @@ def run_fwd_bwd_vschedule_with_optim(
|
|||
@pytest.mark.parametrize("num_model_chunk", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
|
||||
# spawn(
|
||||
# run_fwd_bwd_with_vschedule,
|
||||
# nprocs=4,
|
||||
# num_microbatch=num_microbatch,
|
||||
# batch_size=batch_size,
|
||||
# num_model_chunk=num_model_chunk,
|
||||
# )
|
||||
|
||||
spawn(
|
||||
run_fwd_bwd_with_vschedule,
|
||||
run_fwd_bwd_vschedule_with_optim,
|
||||
nprocs=4,
|
||||
num_microbatch=num_microbatch,
|
||||
batch_size=batch_size,
|
||||
|
|
Loading…
Reference in New Issue