[fix] fix zerobubble; support shardformer model type;

pull/6069/head
duanjunwen 2024-09-26 06:11:56 +00:00
parent 83163fa70c
commit a92e16719b
3 changed files with 109 additions and 129 deletions

View File

@@ -431,7 +431,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # fwd calculate
         internal_inputs = {} if input_obj is None else input_obj
         internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-        output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+        output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -562,7 +562,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
            optimizer.backward_by_grad(
                tensor=output_obj_,
                grad=output_obj_grad_,
-                inputs=list(model_chunk[model_chunk_id].parameters()),
+                inputs=list(model_chunk.parameters()),
                retain_graph=False,
            )
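
Context for the two changes above: with a shardformer-style model there is one module per rank rather than a list indexed by `model_chunk_id`, so both the forward call and the `inputs=` list for `backward_by_grad` now take the whole `model_chunk`. A minimal sketch of the input-restricted backward this relies on, written against plain `torch.autograd.backward` (assumed semantics of `OptimizerWrapper.backward_by_grad`, not its actual implementation):

    import torch

    layer = torch.nn.Linear(4, 4, bias=False)
    x = torch.rand(2, 4, requires_grad=True)
    out = layer(x)
    # accumulate grads only into the listed parameters,
    # as in inputs=list(model_chunk.parameters()) above
    torch.autograd.backward(
        tensors=out,
        grad_tensors=torch.ones_like(out),
        inputs=list(layer.parameters()),
    )
    print(layer.weight.grad.shape)  # torch.Size([4, 4]); x.grad stays None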

View File

@ -26,6 +26,7 @@ class PipelineStageManager:
pg_mesh: ProcessGroupMesh, pg_mesh: ProcessGroupMesh,
pipeline_axis: int, pipeline_axis: int,
enable_interleave: bool = False, enable_interleave: bool = False,
use_zbv: bool = False,
num_model_chunks: int = 1, num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None, num_layers_per_stage: Optional[List[int]] = None,
) -> None: ) -> None:
@ -49,6 +50,7 @@ class PipelineStageManager:
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
self.is_interleave = enable_interleave self.is_interleave = enable_interleave
self.use_zbv = use_zbv
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model # for shardformer, hold stage indices of model
@ -85,6 +87,16 @@ class PipelineStageManager:
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
stage_indices = [] stage_indices = []
if self.use_zbv:
stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
stage_indices.append(
[
num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
num_layers_per_stage_accumulated[2 * num_stages - stage],
]
)
return stage_indices
for model_chunk in range(num_model_chunks): for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
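
For context: with `use_zbv=True`, `get_stage_index` gives each stage two mirrored layer ranges, one counted from the front of the pipeline and one from the back, which is the V shape of the zero-bubble schedule. A quick sketch of the prefix-sum arithmetic above under assumed toy sizes (4 stages, 8 layers, one layer per stage per chunk):

    import numpy as np

    num_stages, layers_per_stage = 4, [1] * 8  # 8 layers = 2 chunks x 4 stages
    acc = np.insert(np.cumsum(layers_per_stage), 0, 0).tolist()  # [0, 1, ..., 8]
    for stage in range(num_stages):
        chunk0 = [acc[stage], acc[stage + 1]]  # descending leg of the V
        chunk1 = [acc[2 * num_stages - stage - 1], acc[2 * num_stages - stage]]  # ascending leg
        print(f"stage {stage}: {chunk0} and {chunk1}")
    # stage 0: [0, 1] and [7, 8]; stage 3: [3, 4] and [4, 5]

This is exactly the mapping the test below asserts: rank 0 holds layers 0 and 7, rank 3 holds layers 3 and 4.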

View File

@ -1,6 +1,5 @@
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from types import MethodType
from typing import Tuple from typing import Tuple
import pytest import pytest
@ -22,37 +21,54 @@ from tests.kit.model_zoo import model_zoo
class MlpModel(nn.Module): class MlpModel(nn.Module):
def __init__(self, in_dim, out_dim, num_layers): def __init__(
self,
in_dim,
out_dim,
num_layers,
stage_index=None,
stage_mgr: PipelineStageManager = None,
):
super().__init__() super().__init__()
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
# self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
# if stage_mgr:
# self.held_layers = self.layers[stage_index[0]:stage_index[1]]
def forward( def forward(
self, self,
hidden_states, model=None,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_index=None,
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
): ):
for layer in self.layers: if stage_mgr is None:
hidden_states = layer(hidden_states) hidden_states = data
return hidden_states for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
def pp_linear_fwd(
forward,
data: torch.Tensor = None,
hidden_states: torch.Tensor = None,
stage_index=None,
stage_mgr: PipelineStageManager = None,
model_chunk_id: int = None,
):
with stage_mgr.switch_model_chunk_id(model_chunk_id):
# fwd end
if stage_mgr.is_first_stage() and model_chunk_id == 1:
return forward(hidden_states)
# fwd start
elif stage_mgr.is_first_stage() and model_chunk_id == 0:
return {"hidden_states": forward(data)}
# fwd middle
else: else:
return {"hidden_states": forward(hidden_states)} # Set not used layer to None
held_layers = self.layers[stage_index[0] : stage_index[1]]
# fwd end
if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1:
return held_layers(hidden_states)
# fwd start
elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0:
return {"hidden_states": held_layers(data)}
# fwd middle
else:
return {"hidden_states": held_layers(hidden_states)}
def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
if key_base == key_pp:
if key_base != "params":
assert val_base == val_pp
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
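
A short usage sketch of the revised test model (toy sizes assumed; `MlpModel` as defined above). Without a stage manager it behaves like a plain MLP; with one, only the `nn.Sequential` slice named by `stage_index` runs:

    import torch

    model = MlpModel(in_dim=8, out_dim=8, num_layers=8)
    x = torch.rand(2, 8, 8)
    out_full = model.forward(data=x)  # baseline path: all 8 layers

    # pipeline path, shape only (stage manager calls elided): a stage holding
    # layers [0, 1) applies its slice and wraps the output for the next stage
    held_layers = model.layers[0:1]
    out_stage = {"hidden_states": held_layers(x)}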
@ -555,7 +571,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
num_model_chunk = test_config["num_model_chunk"] num_model_chunk = test_config["num_model_chunk"]
# stage_manager # stage_manager
stage_manager = PipelineStageManager( stage_manager = PipelineStageManager(
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True
) )
h, a, s = 4096, 32, 1024 h, a, s = 4096, 32, 1024
@ -601,69 +617,30 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
before_init_memory = torch.cuda.memory_allocated() / 1024**3 before_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
# data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
# input_base = [t.clone() for t in data_iter]
input_base = {k: v.clone() for k, v in data_iter.items()} input_base = {k: v.clone() for k, v in data_iter.items()}
model_base = deepcopy(model) model_base = deepcopy(model)
model_pp = deepcopy(model)
layers_per_stage = stage_manager.distribute_layers(len(model.layers)) layers_per_stage = stage_manager.distribute_layers(len(model.layers))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
if rank == 0: model_pp._forward = model_pp.forward
# layer 0 & 7 to chunk 0 on rank0 # model_pp.forward = MethodType(
local_chunk = torch.nn.ModuleList().to(rank) # partial(model_pp._forward, stage_mgr=stage_manager),
for idx, sub_model in enumerate(model.layers): # model_pp,
if idx == 0 or idx == 7: # )
sub_model._forward = sub_model.forward model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 1:
# layer 1 & 6 to chunk 1 on rank1
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 1 or idx == 6:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
elif rank == 2:
# layer 2 & 5 to chunk 2 on rank2
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 2 or idx == 5:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
else:
# layer 3 & 4 to chunk 3 on rank3
local_chunk = torch.nn.ModuleList().to(rank)
for idx, sub_model in enumerate(model.layers):
if idx == 3 or idx == 4:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
sub_model._forward,
)
local_chunk.append(sub_model)
# init optimizer # init optimizer
optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5) optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5)) optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
after_init_memory = torch.cuda.memory_allocated() / 1024**3 after_init_memory = torch.cuda.memory_allocated() / 1024**3
print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};") print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")
torch.cuda.synchronize() torch.cuda.synchronize()
result = scheduler.forward_backward_step( result = scheduler.forward_backward_step(
model_chunk=local_chunk, model_chunk=model_pp,
data_iter=iter([data_iter]), data_iter=iter([data_iter]),
criterion=criterion, criterion=criterion,
optimizer=optimizer_pp, optimizer=optimizer_pp,
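
The `partial` rebinding above replaces the old per-layer `MethodType` patching: `model_pp.forward` is already bound to the instance, so wrapping it with `functools.partial` pre-fills `stage_mgr` and leaves the remaining keyword arguments for the scheduler to supply. A standalone sketch of the same pattern (hypothetical `Toy` class, not part of the commit):

    from functools import partial

    class Toy:
        def fwd(self, data=None, scale=1):
            return data * scale

    t = Toy()
    t._fwd = t.fwd                    # keep the original bound method
    t.fwd = partial(t._fwd, scale=2)  # pre-fill one kwarg, like stage_mgr above
    print(t.fwd(data=3))              # 6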
@@ -697,7 +674,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(input_base["data"])
+    # output_base = model_base(input_base["data"])
+    output_base = model_base.forward(data=input_base["data"])
     loss_base = criterion_base(output_base)
     loss_base.backward()
     optimizer_base.step()
@@ -710,63 +688,53 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     assert_close(result["loss"], loss_base)
     assert_close(result["outputs"]["hidden_states"], output_base)

-    # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
-    ##########################
-    # assert weight
-    ##########################
-    if rank == 0:
-        # layer 0
-        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
-        # layer 7
-        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
-    if rank == 1:
-        # layer 1
-        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
-        # layer 6
-        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
-    if rank == 2:
-        # layer 2
-        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
-        # layer 5
-        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
-    if rank == 3:
-        # layer 3
-        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
-        # layer 4
-        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
-
-    ##########################
-    # assert optim state
-    ##########################
+    # ##########################
+    # # assert weight & optim state
+    # ##########################
     optim_base_state = optimizer_base.state_dict()["state"]
     optim_pp_state = optimizer_pp.state_dict()["state"]
     optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
     optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
-    # if rank == 0:
-    #     print(f"optim_base_state {optim_base_state}")

-    # assert param group
-    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
-        if key_base == key_pp:
-            if key_base != "params":
-                assert val_base == val_pp
-            else:
-                # BUG:
-                # param_base: [0, 1, 2, 3, 4, 5, 6, 7];
-                # params pp: [0, 1];
-                assert val_base[:2] == val_pp
-
-    # assert state
-    assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
-    assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
+    if rank == 0:
+        # layer 0
+        assert_close(model_pp.layers[0].weight, model_base.layers[0].weight)
+        assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad)
+        assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"])
+        # layer 7
+        assert_close(model_pp.layers[7].weight, model_base.layers[7].weight)
+        assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad)
+        assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"])
+    if rank == 1:
+        # layer 1
+        assert_close(model_pp.layers[1].weight, model_base.layers[1].weight)
+        assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad)
+        assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"])
+        # layer 6
+        assert_close(model_pp.layers[6].weight, model_base.layers[6].weight)
+        assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad)
+        assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"])
+    if rank == 2:
+        # layer 2
+        assert_close(model_pp.layers[2].weight, model_base.layers[2].weight)
+        assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad)
+        assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"])
+        # layer 5
+        assert_close(model_pp.layers[5].weight, model_base.layers[5].weight)
+        assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad)
+        assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"])
+    if rank == 3:
+        # layer 3
+        assert_close(model_pp.layers[3].weight, model_base.layers[3].weight)
+        assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad)
+        assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"])
+        # layer 4
+        assert_close(model_pp.layers[4].weight, model_base.layers[4].weight)
+        assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad)
+        assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"])

+    # assert optim param_groups
+    assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)

     # TODO:4) support Hybrid base 3)
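
One consequence worth noting (an observation about the diff above, not part of it): since `model_pp` now owns all eight layers, SGD's state keys line up one-to-one with the baseline, so `optim_pp_state[i]` is compared directly against `optim_base_state[i]` instead of the old `2 * rank` offset used with the two-parameter `local_chunk`. A small sketch of why, under the same SGD settings:

    import torch

    m = torch.nn.Sequential(*[torch.nn.Linear(4, 4, bias=None) for _ in range(8)])
    opt = torch.optim.SGD(m.parameters(), momentum=0.1, lr=1e-5)
    m(torch.rand(2, 4)).sum().backward()
    opt.step()  # creates one momentum_buffer per parameter
    print(sorted(opt.state_dict()["state"].keys()))  # [0, 1, 2, 3, 4, 5, 6, 7]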