Merge pull request #6069 from duanjunwen/dev/zero_bubble

[HotFix] Fix stage_index in zerobubble test;
pull/6075/head
duanjunwen 2024-09-27 10:34:04 +08:00 committed by GitHub
commit b804fdc297
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 122 additions and 140 deletions

View File

@ -430,8 +430,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
internal_inputs = {} if input_obj is None else input_obj
# internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
micro_batch: Optional[dict],
# micro_batch: Optional[dict],
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
@ -478,11 +478,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_ = []
output_obj_grad_ = []
# For chunk 0 stage 0, use micro_batch as input_obj_
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj_, _ = tree_flatten(micro_batch)
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
return None
# For loss backward; output_obj is loss; output_obj_grad should be None
elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@ -497,6 +495,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# filter item which is not torch.Tensor
input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
@ -507,9 +510,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Format output_obj_grad
input_obj_grad = {}
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
for k, v in micro_batch.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
pass
else:
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
@ -550,10 +551,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# filter item which is not torch.Tensor
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
inputs=list(model_chunk[model_chunk_id].parameters()),
inputs=list(model_chunk.parameters()),
retain_graph=False,
)
@ -634,7 +639,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
tree_map(release_tensor_data, output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
self.input_tensors[model_chunk_id].append(input_obj)
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not release_tensor_data loss, release_tensor_data other output_obj;
@ -692,7 +697,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
# get input and output object from buffer;
micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
@ -708,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
micro_batch=micro_batch,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,

View File

@ -26,6 +26,7 @@ class PipelineStageManager:
pg_mesh: ProcessGroupMesh,
pipeline_axis: int,
enable_interleave: bool = False,
use_zbv: bool = False,
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
) -> None:
@ -49,6 +50,7 @@ class PipelineStageManager:
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
self.is_interleave = enable_interleave
self.use_zbv = use_zbv
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
@ -85,6 +87,16 @@ class PipelineStageManager:
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
stage_indices = []
if self.use_zbv:
stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
stage_indices.append(
[
num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
num_layers_per_stage_accumulated[2 * num_stages - stage],
]
)
return stage_indices
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]

View File

@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
self.is_interleave = False
self.num_layers_per_stage = None
self.num_model_chunks = 1
self.use_zbv = False
@property
def num_stages(self):

View File

@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
self.is_interleave = False
self.num_layers_per_stage = None
self.num_model_chunks = 1
self.use_zbv = False
@property
def num_stages(self):

View File

@ -1,6 +1,5 @@
from copy import deepcopy
from functools import partial
from types import MethodType
from typing import Tuple
import pytest
@ -22,36 +21,50 @@ from tests.kit.model_zoo import model_zoo
class MlpModel(nn.Module):
def __init__(self, in_dim, out_dim, num_layers):
def __init__(
self,
in_dim,
out_dim,
num_layers,
stage_index=None,
stage_mgr: PipelineStageManager = None,
):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
def forward(
self,
hidden_states,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_index=None,
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
def pp_linear_fwd(
forward,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
):
with stage_mgr.switch_model_chunk_id(model_chunk_id):
# fwd end
if stage_mgr.is_first_stage() and model_chunk_id == 1:
return forward(hidden_states)
# fwd start
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
return {"hidden_states": forward(data)}
# fwd middle
if stage_mgr is None:
hidden_states = data
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
else:
return {"hidden_states": forward(hidden_states)}
# Set not used layer to None
held_layers = self.layers[stage_index[0] : stage_index[1]]
# fwd end
if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1:
return held_layers(hidden_states)
# fwd start
elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0:
return {"hidden_states": held_layers(data)}
# fwd middle
else:
return {"hidden_states": held_layers(hidden_states)}
def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
if key_base == key_pp:
if key_base != "params":
assert val_base == val_pp
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
@ -554,7 +567,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
num_model_chunk = test_config["num_model_chunk"]
# stage_manager
stage_manager = PipelineStageManager(
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True
)
h, a, s = 4096, 32, 1024
@ -600,67 +613,27 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
before_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
# input_base = [t.clone() for t in data_iter]
input_base = {k: v.clone() for k, v in data_iter.items()}
model_base = deepcopy(model)
model_pp = deepcopy(model)
layers_per_stage = stage_manager.distribute_layers(len(model.layers))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
if rank == 0:
# layer 0 & 7 to chunk 0 on rank0
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 0 or idx == 7:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
model_pp._forward = model_pp.forward
model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
# init optimizer
optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
after_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
torch.cuda.synchronize()
result = scheduler.forward_backward_step(
model_chunk=local_chunk,
model_chunk=model_pp,
data_iter=iter([data_iter]),
criterion=criterion,
optimizer=optimizer_pp,
@ -694,7 +667,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
# Fwd bwd for base
##########################
# fwd & bwd
output_base = model_base(input_base["data"])
# output_base = model_base(input_base["data"])
output_base = model_base.forward(data=input_base["data"])
loss_base = criterion_base(output_base)
loss_base.backward()
optimizer_base.step()
@ -707,63 +681,53 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
assert_close(result["loss"], loss_base)
assert_close(result["outputs"]["hidden_states"], output_base)
# print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
##########################
# assert weight
##########################
if rank == 0:
# layer 0
assert_close(local_chunk[0].weight, model_base.layers[0].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
# layer 7
assert_close(local_chunk[1].weight, model_base.layers[7].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
if rank == 1:
# layer 1
assert_close(local_chunk[0].weight, model_base.layers[1].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
# layer 6
assert_close(local_chunk[1].weight, model_base.layers[6].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
if rank == 2:
# layer 2
assert_close(local_chunk[0].weight, model_base.layers[2].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
# layer 5
assert_close(local_chunk[1].weight, model_base.layers[5].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
if rank == 3:
# layer 3
assert_close(local_chunk[0].weight, model_base.layers[3].weight)
assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
# layer 4
assert_close(local_chunk[1].weight, model_base.layers[4].weight)
assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
##########################
# assert optim state
##########################
# ##########################
# # assert weight & optim state
# ##########################
optim_base_state = optimizer_base.state_dict()["state"]
optim_pp_state = optimizer_pp.state_dict()["state"]
optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
# if rank == 0:
# print(f"optim_base_state {optim_base_state}")
# assert param group
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
if key_base == key_pp:
if key_base != "params":
assert val_base == val_pp
else:
# BUG:
# param_base: [0, 1, 2, 3, 4, 5, 6, 7];
# params pp: [0, 1];
assert val_base[:2] == val_pp
if rank == 0:
# layer 0
assert_close(model_pp.layers[0].weight, model_base.layers[0].weight)
assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad)
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"])
# layer 7
assert_close(model_pp.layers[7].weight, model_base.layers[7].weight)
assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad)
assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"])
if rank == 1:
# layer 1
assert_close(model_pp.layers[1].weight, model_base.layers[1].weight)
assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad)
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"])
# layer 6
assert_close(model_pp.layers[6].weight, model_base.layers[6].weight)
assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad)
assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"])
if rank == 2:
# layer 2
assert_close(model_pp.layers[2].weight, model_base.layers[2].weight)
assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad)
assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"])
# layer 5
assert_close(model_pp.layers[5].weight, model_base.layers[5].weight)
assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad)
assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"])
if rank == 3:
# layer 3
assert_close(model_pp.layers[3].weight, model_base.layers[3].weight)
assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad)
assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"])
# layer 4
assert_close(model_pp.layers[4].weight, model_base.layers[4].weight)
assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad)
assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"])
# assert state
assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
# assert optim param_groups
assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)
# TODO:4) support Hybrid base 3)