Merge pull request #6069 from duanjunwen/dev/zero_bubble

[HotFix] Fix stage_index in zerobubble test;
pull/6075/head
duanjunwen 2024-09-27 10:34:04 +08:00 committed by GitHub
commit b804fdc297
5 changed files with 122 additions and 140 deletions

View File

@@ -430,8 +430,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             # fwd calculate
             internal_inputs = {} if input_obj is None else input_obj
-            # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        micro_batch: Optional[dict],
+        # micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -478,11 +478,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = []
         output_obj_grad_ = []
-        # For chunk 0 stage 0, use micro_batch as input_obj_
+        # For chunk 0 stage 0, use micro_batch as input_obj_; we don't have to calculate micro-batch dx.
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            return None

         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -497,6 +495,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_, _ = tree_flatten(output_obj)  # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

+        # filter out items that are not torch.Tensor
+        input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
+
         optimizer.backward_by_grad(
             tensor=output_obj_,
             grad=output_obj_grad_,
@@ -507,9 +510,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Format output_obj_grad
         input_obj_grad = {}
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            for k, v in micro_batch.items():
-                if isinstance(v, torch.Tensor) and v.grad is not None:
-                    input_obj_grad[k] = v.grad
+            pass
         else:
             for k, v in input_obj.items():
                 if isinstance(v, torch.Tensor) and v.grad is not None:
@@ -550,10 +551,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_, _ = tree_flatten(output_obj)  # y
             output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

+        # filter out items that are not torch.Tensor
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
+
         optimizer.backward_by_grad(
             tensor=output_obj_,
             grad=output_obj_grad_,
-            inputs=list(model_chunk[model_chunk_id].parameters()),
+            inputs=list(model_chunk.parameters()),
             retain_graph=False,
         )
@@ -634,7 +639,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         tree_map(release_tensor_data, output_obj)
         # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+        self.input_tensors[model_chunk_id].append(input_obj)
         # for bwd b&w, we only need the graph (grad_fn) of output_obj
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
@@ -692,7 +697,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
         # get input and output object from buffer;
-        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
+        input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
         # save output_tensor_grad for dw
@@ -708,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
-            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
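Note on this file's hunks: each rank now drives a single model chunk whose layers are selected by stage_index (instead of indexing model_chunk[model_chunk_id]), micro_batch is no longer threaded through the backward buffers, and everything handed to optimizer.backward_by_grad is first tree-flattened and filtered down to tensors so the injected stage_index metadata never reaches autograd. A minimal sketch of that filtering pattern (flatten_tensors is a hypothetical helper, not part of the patch):

import torch
from torch.utils._pytree import tree_flatten

def flatten_tensors(obj):
    # Flatten a nested output structure, then keep only tensors (or None
    # placeholders); this mirrors the filtering the hunks above apply
    # before optimizer.backward_by_grad sees the values.
    flat, _ = tree_flatten(obj)
    return [v for v in flat if isinstance(v, torch.Tensor) or v is None]

# Hypothetical forward output mixing a tensor with non-tensor metadata:
output_obj = {"hidden_states": torch.randn(2, 4, requires_grad=True), "stage_index": [0, 2]}
print(flatten_tensors(output_obj))  # only the hidden_states tensor survives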

View File

@@ -26,6 +26,7 @@ class PipelineStageManager:
         pg_mesh: ProcessGroupMesh,
         pipeline_axis: int,
         enable_interleave: bool = False,
+        use_zbv: bool = False,
         num_model_chunks: int = 1,
         num_layers_per_stage: Optional[List[int]] = None,
     ) -> None:
@@ -49,6 +50,7 @@ class PipelineStageManager:
         next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
         self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
         self.is_interleave = enable_interleave
+        self.use_zbv = use_zbv
         # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
         self.num_model_chunks: int = num_model_chunks
         # for shardformer, hold stage indices of model
@@ -85,6 +87,16 @@ class PipelineStageManager:
         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
         stage_indices = []
+        if self.use_zbv:
+            stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
+            stage_indices.append(
+                [
+                    num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
+                    num_layers_per_stage_accumulated[2 * num_stages - stage],
+                ]
+            )
+            return stage_indices
         for model_chunk in range(num_model_chunks):
             start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
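Note: with use_zbv, each stage holds exactly two chunks laid out in a V shape: chunk 0 takes a block of layers counted from the front, chunk 1 the mirrored block counted from the back. A standalone arithmetic check of the indexing above, assuming the test's configuration (8 layers, 4 stages, 1 layer per virtual stage):

import numpy as np

num_stages = 4
layers_per_stage = [1] * (2 * num_stages)  # two chunks per stage
acc = np.insert(np.cumsum(layers_per_stage), 0, 0)

for stage in range(num_stages):
    chunk0 = [int(acc[stage]), int(acc[stage + 1])]
    chunk1 = [int(acc[2 * num_stages - stage - 1]), int(acc[2 * num_stages - stage])]
    print(f"stage {stage}: chunk0 {chunk0}, chunk1 {chunk1}")

# stage 0: chunk0 [0, 1], chunk1 [7, 8]  -> layers 0 and 7
# stage 1: chunk0 [1, 2], chunk1 [6, 7]  -> layers 1 and 6
# stage 2: chunk0 [2, 3], chunk1 [5, 6]  -> layers 2 and 5
# stage 3: chunk0 [3, 4], chunk1 [4, 5]  -> layers 3 and 4

This reproduces the layer-to-rank assignment the old test hard-coded per rank.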

View File

@@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
         self.is_interleave = False
         self.num_layers_per_stage = None
         self.num_model_chunks = 1
+        self.use_zbv = False

     @property
     def num_stages(self):

View File

@@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
         self.is_interleave = False
         self.num_layers_per_stage = None
         self.num_model_chunks = 1
+        self.use_zbv = False

     @property
     def num_stages(self):

View File

@ -1,6 +1,5 @@
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from types import MethodType
from typing import Tuple from typing import Tuple
import pytest import pytest
@@ -22,36 +21,50 @@ from tests.kit.model_zoo import model_zoo
 class MlpModel(nn.Module):
-    def __init__(self, in_dim, out_dim, num_layers):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        num_layers,
+        stage_index=None,
+        stage_mgr: PipelineStageManager = None,
+    ):
         super().__init__()
-        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
+        self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])

     def forward(
         self,
-        hidden_states,
-    ):
-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
-        return hidden_states
-
-
-def pp_linear_fwd(
-    forward,
-    data: torch.Tensor = None,
-    hidden_states: torch.Tensor = None,
-    stage_mgr: PipelineStageManager = None,
-    model_chunk_id: int = None,
-):
-    with stage_mgr.switch_model_chunk_id(model_chunk_id):
-        # fwd end
-        if stage_mgr.is_first_stage() and model_chunk_id == 1:
-            return forward(hidden_states)
-        # fwd start
-        elif stage_mgr.is_first_stage() and model_chunk_id == 0:
-            return {"hidden_states": forward(data)}
-        # fwd middle
-        else:
-            return {"hidden_states": forward(hidden_states)}
+        data: torch.Tensor = None,
+        hidden_states: torch.Tensor = None,
+        stage_index=None,
+        stage_mgr: PipelineStageManager = None,
+        model_chunk_id: int = None,
+    ):
+        if stage_mgr is None:
+            hidden_states = data
+            for layer in self.layers:
+                hidden_states = layer(hidden_states)
+            return hidden_states
+        else:
+            # run only the layers held by this stage
+            held_layers = self.layers[stage_index[0] : stage_index[1]]
+            # fwd end
+            if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1:
+                return held_layers(hidden_states)
+            # fwd start
+            elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0:
+                return {"hidden_states": held_layers(data)}
+            # fwd middle
+            else:
+                return {"hidden_states": held_layers(hidden_states)}
+
+
+def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
+        if key_base == key_pp:
+            if key_base != "params":
+                assert val_base == val_pp


 def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
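Note: the rewritten test model leans on the fact that slicing an nn.Sequential returns another nn.Sequential, so held_layers is directly callable and shares its parameters with the full model. A quick sketch of that property (dimensions are arbitrary):

import torch.nn as nn

layers = nn.Sequential(*[nn.Linear(8, 8, bias=None) for _ in range(8)])
held = layers[0:2]  # slicing returns an nn.Sequential, not a plain list
assert isinstance(held, nn.Sequential)
# the slice reuses the same module objects, so weights (and optimizer
# state indices) line up with the full model
assert held[0].weight is layers[0].weight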
@@ -554,7 +567,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     num_model_chunk = test_config["num_model_chunk"]
     # stage_manager
     stage_manager = PipelineStageManager(
-        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True
     )

     h, a, s = 4096, 32, 1024
@@ -600,67 +613,27 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-    # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
     data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
-    # input_base = [t.clone() for t in data_iter]
     input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
+    model_pp = deepcopy(model)
+    layers_per_stage = stage_manager.distribute_layers(len(model.layers))
+    stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)

-    if rank == 0:
-        # layer 0 & 7 to chunk 0 on rank0
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 0 or idx == 7:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
-    elif rank == 1:
-        # layer 1 & 6 to chunk 1 on rank1
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 1 or idx == 6:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
-    elif rank == 2:
-        # layer 2 & 5 to chunk 2 on rank2
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 2 or idx == 5:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
-    else:
-        # layer 3 & 4 to chunk 3 on rank3
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 3 or idx == 4:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
+    model_pp._forward = model_pp.forward
+    model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)

     # init optimizer
     optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
-    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))

     after_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")

     torch.cuda.synchronize()
     result = scheduler.forward_backward_step(
-        model_chunk=local_chunk,
+        model_chunk=model_pp,
         data_iter=iter([data_iter]),
         criterion=criterion,
         optimizer=optimizer_pp,
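Note: the per-rank MethodType surgery removed above collapses into a single partial that pre-binds stage_mgr; the scheduler injects stage_index at call time through internal_inputs (first file), so every rank runs the same model_pp. A toy illustration of the rebinding pattern, using a made-up class:

from functools import partial

class Toy:
    def forward(self, data=None, stage_mgr=None):
        return ("pp" if stage_mgr is not None else "base", data)

toy = Toy()
toy._forward = toy.forward                            # keep the original bound method
toy.forward = partial(toy._forward, stage_mgr="mgr")  # pre-bind the stage manager
print(toy.forward(data=1))   # ('pp', 1): stage_mgr arrives automatically
print(toy._forward(data=1))  # ('base', 1): the unbound baseline path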
@@ -694,7 +667,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(input_base["data"])
+    # output_base = model_base(input_base["data"])
+    output_base = model_base.forward(data=input_base["data"])
     loss_base = criterion_base(output_base)
     loss_base.backward()
     optimizer_base.step()
@@ -707,63 +681,53 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     assert_close(result["loss"], loss_base)
     assert_close(result["outputs"]["hidden_states"], output_base)

-    # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
-    ##########################
-    # assert weight
-    ##########################
-    if rank == 0:
-        # layer 0
-        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
-        # layer 7
-        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
-    if rank == 1:
-        # layer 1
-        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
-        # layer 6
-        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
-    if rank == 2:
-        # layer 2
-        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
-        # layer 5
-        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
-    if rank == 3:
-        # layer 3
-        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
-        # layer 4
-        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
-
-    ##########################
-    # assert optim state
-    ##########################
+    # ##########################
+    # # assert weight & optim state
+    # ##########################
     optim_base_state = optimizer_base.state_dict()["state"]
     optim_pp_state = optimizer_pp.state_dict()["state"]
     optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
     optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
-    # if rank == 0:
-    #     print(f"optim_base_state {optim_base_state}")

-    # assert param group
-    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
-        if key_base == key_pp:
-            if key_base != "params":
-                assert val_base == val_pp
-            else:
-                # BUG:
-                # param_base: [0, 1, 2, 3, 4, 5, 6, 7];
-                # params pp: [0, 1];
-                assert val_base[:2] == val_pp
+    if rank == 0:
+        # layer 0
+        assert_close(model_pp.layers[0].weight, model_base.layers[0].weight)
+        assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad)
+        assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"])
+        # layer 7
+        assert_close(model_pp.layers[7].weight, model_base.layers[7].weight)
+        assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad)
+        assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"])
+    if rank == 1:
+        # layer 1
+        assert_close(model_pp.layers[1].weight, model_base.layers[1].weight)
+        assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad)
+        assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"])
+        # layer 6
+        assert_close(model_pp.layers[6].weight, model_base.layers[6].weight)
+        assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad)
+        assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"])
+    if rank == 2:
+        # layer 2
+        assert_close(model_pp.layers[2].weight, model_base.layers[2].weight)
+        assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad)
+        assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"])
+        # layer 5
+        assert_close(model_pp.layers[5].weight, model_base.layers[5].weight)
+        assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad)
+        assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"])
+    if rank == 3:
+        # layer 3
+        assert_close(model_pp.layers[3].weight, model_base.layers[3].weight)
+        assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad)
+        assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"])
+        # layer 4
+        assert_close(model_pp.layers[4].weight, model_base.layers[4].weight)
+        assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad)
+        assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"])

-    # assert state
-    assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
-    assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
+    # assert optim param_groups
+    assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)

     # TODO:4) support Hybrid base 3)
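Note: because optimizer_pp now wraps the full model_pp.parameters(), its state dict is keyed by global parameter index, which is why optim_pp_state[7] can be compared directly against optim_base_state[7] above (the old local_chunk exposed only two params, keyed 0 and 1, forcing the 2 * rank arithmetic). A minimal sketch of how SGD keys its state, under the same 8-layer layout:

import torch
from torch.optim import SGD

model = torch.nn.Sequential(*[torch.nn.Linear(4, 4, bias=None) for _ in range(8)])
opt = SGD(model.parameters(), lr=1e-5, momentum=0.1)
model(torch.randn(2, 4)).sum().backward()
opt.step()
# one state entry (with a momentum_buffer) per parameter, in registration
# order: keys 0..7 map straight onto Sequential layer indices 0..7
print(sorted(opt.state_dict()["state"].keys()))  # [0, 1, 2, 3, 4, 5, 6, 7]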