From fc8b016887e48b03a52e789770d295a8a9842943 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Wed, 25 Sep 2024 06:15:45 +0000
Subject: [PATCH 1/8] [fix] fix stage_indices;

---
 .../pipeline/schedule/zero_bubble_pp.py | 26 ++++++++++++-------
 .../test_schedule/test_zerobubble_pp.py |  3 +++
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index bbad921b2..307d1035c 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -430,7 +430,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             # fwd calculate
             internal_inputs = {} if input_obj is None else input_obj
-            # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
             output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)

         # last layer in model
@@ -480,22 +480,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy

         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            input_obj_, _ = tree_flatten(input_obj)
+            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(output_obj_grad)  # None

         # For other chunk stage, use input_obj as input_obj_;
         else:
-            input_obj_, _ = tree_flatten(input_obj)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy

         optimizer.backward_by_grad(
             tensor=output_obj_,
@@ -547,8 +551,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy

         optimizer.backward_by_grad(
             tensor=output_obj_,
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 14bc3475d..9fa636504 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -39,6 +39,7 @@ def pp_linear_fwd(
     forward,
     data: torch.Tensor = None,
     hidden_states: torch.Tensor = None,
+    stage_index=None,
     stage_mgr: PipelineStageManager = None,
     model_chunk_id: int = None,
 ):
@@ -605,6 +606,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # input_base = [t.clone() for t in data_iter]
     input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
+    layers_per_stage = stage_manager.distribute_layers(len(model.layers))
+    stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)

     if rank == 0:
         # layer 0 & 7 to chunk 0 on rank0

From 83163fa70c49085b15ea063fcb0ee188d28f4871 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Wed, 25 Sep 2024 06:38:11 +0000
Subject: [PATCH 2/8] [fix] fix traverse; traverse dict --> traverse tensor List;

---
 .../pipeline/schedule/zero_bubble_pp.py | 33 ++++++++++---------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 307d1035c..0272cc113 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -480,26 +480,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            input_obj_, _ = tree_flatten(micro_batch)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy

         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
+            input_obj_, _ = tree_flatten(input_obj)
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(output_obj_grad)  # None

         # For other chunk stage, use input_obj as input_obj_;
         else:
-            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            input_obj_, _ = tree_flatten(input_obj)
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
+        # filter items which are not torch.Tensor
+        input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

         optimizer.backward_by_grad(
             tensor=output_obj_,
@@ -551,10 +552,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
-            output_obj_grad_, _ = tree_flatten(
-                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
-            )  # dy
+            output_obj_, _ = tree_flatten(output_obj)  # y
+            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+
+        # filter items which are not torch.Tensor
+        output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
+        output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]

         optimizer.backward_by_grad(
             tensor=output_obj_,

From a92e16719b870b264c6e3447931a717b648102fa Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 26 Sep 2024 06:11:56 +0000
Subject: [PATCH 3/8] [fix] fix zerobubble; support shardformer model type;

---
 .../pipeline/schedule/zero_bubble_pp.py |   4 +-
 colossalai/pipeline/stage_manager.py    |  12 +
 .../test_schedule/test_zerobubble_pp.py | 222 ++++++++----------
 3 files changed, 109 insertions(+), 129 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 0272cc113..66fbc827b 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -431,7 +431,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # fwd calculate
             internal_inputs = {} if input_obj is None else input_obj
             internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
-            output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
+            output_obj = model_forward(model_chunk, micro_batch, internal_inputs)

         # last layer in model
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
@@ -562,7 +562,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         optimizer.backward_by_grad(
             tensor=output_obj_,
             grad=output_obj_grad_,
-            inputs=list(model_chunk[model_chunk_id].parameters()),
+            inputs=list(model_chunk.parameters()),
             retain_graph=False,
         )

diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py
index 354f110f0..50cc965bb 100644
--- a/colossalai/pipeline/stage_manager.py
+++ b/colossalai/pipeline/stage_manager.py
@@ -26,6 +26,7 @@ class PipelineStageManager:
         pg_mesh: ProcessGroupMesh,
         pipeline_axis: int,
         enable_interleave: bool = False,
+        use_zbv: bool = False,
         num_model_chunks: int = 1,
         num_layers_per_stage: Optional[List[int]] = None,
     ) -> None:
@@ -49,6 +50,7 @@ class PipelineStageManager:
         next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
         self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
         self.is_interleave = enable_interleave
+        self.use_zbv = use_zbv
         # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
         self.num_model_chunks: int = num_model_chunks
         # for shardformer, hold stage indices of model
@@ -85,6 +87,16 @@ class PipelineStageManager:
         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

         stage_indices = []
+        if self.use_zbv:
+            stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
+            stage_indices.append(
+                [
+                    num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
+                    num_layers_per_stage_accumulated[2 * num_stages - stage],
+                ]
+            )
+            return stage_indices
+
         for model_chunk in range(num_model_chunks):
             start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 9fa636504..ccef295d4 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -1,6 +1,5 @@
 from copy import deepcopy
 from functools import partial
-from types import MethodType
 from typing import Tuple

 import pytest
@@ -22,37 +21,54 @@ from tests.kit.model_zoo import model_zoo


 class MlpModel(nn.Module):
-    def __init__(self, in_dim, out_dim, num_layers):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        num_layers,
+        stage_index=None,
+        stage_mgr: PipelineStageManager = None,
+    ):
         super().__init__()
-        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
+        self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
+        # self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
+        # if stage_mgr:
+        #     self.held_layers = self.layers[stage_index[0]:stage_index[1]]

     def forward(
         self,
-        hidden_states,
+        model=None,
+        data: torch.Tensor = None,
+        hidden_states: torch.Tensor = None,
+        stage_index=None,
+        stage_mgr: PipelineStageManager = None,
+        model_chunk_id: int = None,
     ):
-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
-        return hidden_states
-
-
-def pp_linear_fwd(
-    forward,
-    data: torch.Tensor = None,
-    hidden_states: torch.Tensor = None,
-    stage_index=None,
-    stage_mgr: PipelineStageManager = None,
-    model_chunk_id: int = None,
-):
-    with stage_mgr.switch_model_chunk_id(model_chunk_id):
-        # fwd end
-        if stage_mgr.is_first_stage() and model_chunk_id == 1:
-            return forward(hidden_states)
-        # fwd start
-        elif stage_mgr.is_first_stage() and model_chunk_id == 0:
-            return {"hidden_states": forward(data)}
-        # fwd middle
+        if stage_mgr is None:
+            hidden_states = data
+            for layer in self.layers:
+                hidden_states = layer(hidden_states)
+            return hidden_states
         else:
-            return {"hidden_states": forward(hidden_states)}
+            # Only run the layers held by this stage
+            held_layers = self.layers[stage_index[0] : stage_index[1]]
+
+            # fwd end
+            if stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 1:
+                return held_layers(hidden_states)
+            # fwd start
+            elif stage_mgr.is_first_stage() and stage_mgr.model_chunk_id == 0:
+                return {"hidden_states": held_layers(data)}
+            # fwd middle
+            else:
+                return {"hidden_states": held_layers(hidden_states)}
+
+
+def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
+        if key_base == key_pp:
+            if key_base != "params":
+                assert val_base == val_pp


 def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
@@ -555,7 +571,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     num_model_chunk = test_config["num_model_chunk"]
     # stage_manager
     stage_manager = PipelineStageManager(
-        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk, use_zbv=True
     )

     h, a, s = 4096, 32, 1024
@@ -601,69 +617,30 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     before_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
-    # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)]
     data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)}
-    # input_base = [t.clone() for t in data_iter]
     input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
+    model_pp = deepcopy(model)
     layers_per_stage = stage_manager.distribute_layers(len(model.layers))
     stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)

-    if rank == 0:
-        # layer 0 & 7 to chunk 0 on rank0
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 0 or idx == 7:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
-    elif rank == 1:
-        # layer 1 & 6 to chunk 1 on rank1
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 1 or idx == 6:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
-    elif rank == 2:
-        # layer 2 & 5 to chunk 2 on rank2
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 2 or idx == 5:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
-    else:
-        # layer 3 & 4 to chunk 3 on rank3
-        local_chunk = torch.nn.ModuleList().to(rank)
-        for idx, sub_model in enumerate(model.layers):
-            if idx == 3 or idx == 4:
-                sub_model._forward = sub_model.forward
-                sub_model.forward = MethodType(
-                    partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)),
-                    sub_model._forward,
-                )
-                local_chunk.append(sub_model)
+    model_pp._forward = model_pp.forward
+    # model_pp.forward = MethodType(
+    #     partial(model_pp._forward, stage_mgr=stage_manager),
+    #     model_pp,
+    # )
+    model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)

     # init optimizer
     optimizer_base = torch.optim.SGD(model_base.parameters(), momentum=0.1, lr=1e-5)
-    optimizer_pp = OptimizerWrapper(torch.optim.SGD(local_chunk.parameters(), momentum=0.1, lr=1e-5))
+    optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))

     after_init_memory = torch.cuda.memory_allocated() / 1024**3
     print(f"After init Model & input: {after_init_memory :.5f} GB on device {stage_manager.get_rank()};")

     torch.cuda.synchronize()
     result = scheduler.forward_backward_step(
-        model_chunk=local_chunk,
+        model_chunk=model_pp,
         data_iter=iter([data_iter]),
         criterion=criterion,
         optimizer=optimizer_pp,
@@ -697,7 +674,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # Fwd bwd for base
     ##########################
     # fwd & bwd
-    output_base = model_base(input_base["data"])
+    # output_base = model_base(input_base["data"])
+    output_base = model_base.forward(data=input_base["data"])
     loss_base = criterion_base(output_base)
     loss_base.backward()
     optimizer_base.step()
@@ -710,63 +688,53 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     assert_close(result["loss"], loss_base)
     assert_close(result["outputs"]["hidden_states"], output_base)

-    # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ")
-    ##########################
-    # assert weight
-    ##########################
-    if rank == 0:
-        # layer 0
-        assert_close(local_chunk[0].weight, model_base.layers[0].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[0].weight.grad)
-        # layer 7
-        assert_close(local_chunk[1].weight, model_base.layers[7].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[7].weight.grad)
-    if rank == 1:
-        # layer 1
-        assert_close(local_chunk[0].weight, model_base.layers[1].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[1].weight.grad)
-        # layer 6
-        assert_close(local_chunk[1].weight, model_base.layers[6].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
-    if rank == 2:
-        # layer 2
-        assert_close(local_chunk[0].weight, model_base.layers[2].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[2].weight.grad)
-        # layer 5
-        assert_close(local_chunk[1].weight, model_base.layers[5].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
-    if rank == 3:
-        # layer 3
-        assert_close(local_chunk[0].weight, model_base.layers[3].weight)
-        assert_close(local_chunk[0].weight.grad, model_base.layers[3].weight.grad)
-        # layer 4
-        assert_close(local_chunk[1].weight, model_base.layers[4].weight)
-        assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
-
-    ##########################
-    # assert optim state
-    ##########################
+    # ##########################
+    # # assert weight & optim state
+    # ##########################
     optim_base_state = optimizer_base.state_dict()["state"]
    optim_pp_state = optimizer_pp.state_dict()["state"]
     optim_base_param_groups = optimizer_base.state_dict()["param_groups"][0]
     optim_pp_param_groups = optimizer_pp.state_dict()["param_groups"][0]
-    # if rank == 0:
-    #     print(f"optim_base_state {optim_base_state}")

-    # assert param group
-    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
-        if key_base == key_pp:
-            if key_base != "params":
-                assert val_base == val_pp
-            else:
-                # BUG:
-                # param_base: [0, 1, 2, 3, 4, 5, 6, 7];
-                # params pp: [0, 1];
-                assert val_base[:2] == val_pp
+    if rank == 0:
+        # layer 0
+        assert_close(model_pp.layers[0].weight, model_base.layers[0].weight)
+        assert_close(model_pp.layers[0].weight.grad, model_base.layers[0].weight.grad)
+        assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[0]["momentum_buffer"])
+        # layer 7
+        assert_close(model_pp.layers[7].weight, model_base.layers[7].weight)
+        assert_close(model_pp.layers[7].weight.grad, model_base.layers[7].weight.grad)
+        assert_close(optim_pp_state[7]["momentum_buffer"], optim_base_state[7]["momentum_buffer"])
+    if rank == 1:
+        # layer 1
+        assert_close(model_pp.layers[1].weight, model_base.layers[1].weight)
+        assert_close(model_pp.layers[1].weight.grad, model_base.layers[1].weight.grad)
+        assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[1]["momentum_buffer"])
+        # layer 6
+        assert_close(model_pp.layers[6].weight, model_base.layers[6].weight)
+        assert_close(model_pp.layers[6].weight.grad, model_base.layers[6].weight.grad)
+        assert_close(optim_pp_state[6]["momentum_buffer"], optim_base_state[6]["momentum_buffer"])
+    if rank == 2:
+        # layer 2
+        assert_close(model_pp.layers[2].weight, model_base.layers[2].weight)
+        assert_close(model_pp.layers[2].weight.grad, model_base.layers[2].weight.grad)
+        assert_close(optim_pp_state[2]["momentum_buffer"], optim_base_state[2]["momentum_buffer"])
+        # layer 5
+        assert_close(model_pp.layers[5].weight, model_base.layers[5].weight)
+        assert_close(model_pp.layers[5].weight.grad, model_base.layers[5].weight.grad)
+        assert_close(optim_pp_state[5]["momentum_buffer"], optim_base_state[5]["momentum_buffer"])
+    if rank == 3:
+        # layer 3
+        assert_close(model_pp.layers[3].weight, model_base.layers[3].weight)
+        assert_close(model_pp.layers[3].weight.grad, model_base.layers[3].weight.grad)
+        assert_close(optim_pp_state[3]["momentum_buffer"], optim_base_state[3]["momentum_buffer"])
+        # layer 4
+        assert_close(model_pp.layers[4].weight, model_base.layers[4].weight)
+        assert_close(model_pp.layers[4].weight.grad, model_base.layers[4].weight.grad)
+        assert_close(optim_pp_state[4]["momentum_buffer"], optim_base_state[4]["momentum_buffer"])

-    # assert state
-    assert_close(optim_pp_state[0]["momentum_buffer"], optim_base_state[2 * rank]["momentum_buffer"])
-    assert_close(optim_pp_state[1]["momentum_buffer"], optim_base_state[2 * rank + 1]["momentum_buffer"])
+    # assert optim param_groups
+    assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups)


 # TODO:4) support Hybrid base 3)

From 45f17fc6ccb239b64c010831ddbcabe32984e4f3 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 26 Sep 2024 06:13:56 +0000
Subject: [PATCH 4/8] [fix] rm comments;

---
 tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index ccef295d4..46bd4a581 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -31,9 +31,6 @@ class MlpModel(nn.Module):
     ):
         super().__init__()
         self.layers = nn.Sequential(*[nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
-        # self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
-        # if stage_mgr:
-        #     self.held_layers = self.layers[stage_index[0]:stage_index[1]]

     def forward(
         self,

From c5503b0d8063b598ad0410b13afc9ecaf1c0e48b Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 26 Sep 2024 07:18:16 +0000
Subject: [PATCH 5/8] [fix] fix test_pipeline_utils ci;

---
 .../test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py | 1 +
 .../test_pipeline_utils/test_whisper_pipeline_utils.py          | 1 +
 2 files changed, 2 insertions(+)

diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
index e2f71ff89..f79bdeb3a 100644
--- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
+++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py
@@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
         self.is_interleave = False
         self.num_layers_per_stage = None
         self.num_model_chunks = 1
+        self.use_zbv = False

     @property
     def num_stages(self):
diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
index d39c5ea91..722b8fd7c 100644
--- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
+++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py
@@ -15,6 +15,7 @@ class _PipelineStageManager(PipelineStageManager):
         self.is_interleave = False
         self.num_layers_per_stage = None
         self.num_model_chunks = 1
+        self.use_zbv = False

     @property
     def num_stages(self):

From bb0390c90d8645b2d58035e82335049c468d36ec Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 26 Sep 2024 09:45:44 +0000
Subject: [PATCH 6/8] [fix] remove duplicate arg; rm comments;

---
 tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 46bd4a581..0f2d6c49c 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -34,7 +34,6 @@ class MlpModel(nn.Module):

     def forward(
         self,
-        model=None,
         data: torch.Tensor = None,
         hidden_states: torch.Tensor = None,
         stage_index=None,
@@ -622,10 +621,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)

     model_pp._forward = model_pp.forward
-    # model_pp.forward = MethodType(
-    #     partial(model_pp._forward, stage_mgr=stage_manager),
-    #     model_pp,
-    # )
+
     model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)

     # init optimizer

From 64ceea746f5f5463504d5f56c0d69c088f221d5f Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 26 Sep 2024 10:50:44 +0000
Subject: [PATCH 7/8] [fix] remove chunk 0 stage 0 bwd b; we don't have to calculate the micro-batch's dx;

---
 .../pipeline/schedule/zero_bubble_pp.py | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 66fbc827b..8562d23f2 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -449,7 +449,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         optimizer: OptimizerWrapper,
-        micro_batch: Optional[dict],
+        # micro_batch: Optional[dict],
         input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
@@ -480,9 +480,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            # input_obj_, _ = tree_flatten(micro_batch)
+            # output_obj_, _ = tree_flatten(output_obj)  # y
+            # output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            return None

         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
@@ -512,9 +513,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Format output_obj_grad
         input_obj_grad = {}
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            for k, v in micro_batch.items():
-                if isinstance(v, torch.Tensor) and v.grad is not None:
-                    input_obj_grad[k] = v.grad
+            # for k, v in micro_batch.items():
+            #     if isinstance(v, torch.Tensor) and v.grad is not None:
+            #         input_obj_grad[k] = v.grad
+            pass
         else:
             for k, v in input_obj.items():
                 if isinstance(v, torch.Tensor) and v.grad is not None:
@@ -643,7 +645,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         tree_map(release_tensor_data, output_obj)

         # add input and output object for backward b
-        self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+        # self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
+        self.input_tensors[model_chunk_id].append(input_obj)

         # for bwd b&w, we only need the graph(grad_fn) of output_obj
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
@@ -701,7 +704,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)

         # get input and output object from buffer;
-        micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
+        # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
+        input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)

         # save output_tensor_grad for dw
@@ -717,7 +721,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
             optimizer=optimizer,
-            micro_batch=micro_batch,
             input_obj=input_obj,
             output_obj=output_obj,
             output_obj_grad=output_tensor_grad,
@@ -838,6 +841,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         # while we still have schedules_node in self.schedules
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
+        print(f"schedule {schedule}")
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:

From 1342a983b10a1d44632fce5545e3a1a107687082 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 26 Sep 2024 11:05:27 +0000
Subject: [PATCH 8/8] [fix] rm print & comments;

---
 colossalai/pipeline/schedule/zero_bubble_pp.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 8562d23f2..5c25c5bfa 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -478,11 +478,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj_ = []
         output_obj_grad_ = []

-        # For chunk 0 stage 0, use micro_batch as input_obj_
+        # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to calculate the micro-batch's dx.
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # input_obj_, _ = tree_flatten(micro_batch)
-            # output_obj_, _ = tree_flatten(output_obj)  # y
-            # output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
             return None

         # For loss backward; output_obj is loss; output_obj_grad should be None
@@ -513,9 +510,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Format output_obj_grad
         input_obj_grad = {}
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # for k, v in micro_batch.items():
-            #     if isinstance(v, torch.Tensor) and v.grad is not None:
-            #         input_obj_grad[k] = v.grad
             pass
         else:
             for k, v in input_obj.items():
                 if isinstance(v, torch.Tensor) and v.grad is not None:
@@ -645,7 +639,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         tree_map(release_tensor_data, output_obj)

         # add input and output object for backward b
-        # self.input_tensors[model_chunk_id].append((micro_batch, input_obj))
         self.input_tensors[model_chunk_id].append(input_obj)

         # for bwd b&w, we only need the graph(grad_fn) of output_obj
@@ -704,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)

         # get input and output object from buffer;
-        # micro_batch, input_obj = self.input_tensors[model_chunk_id].pop(0)
         input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)

@@ -841,7 +833,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):

         # while we still have schedules_node in self.schedules
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
-        print(f"schedule {schedule}")
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
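
Review note (not part of the patches): patch 3's use_zbv branch in PipelineStageManager.get_stage_index gives each pipeline rank two layer ranges that mirror each other across the "V" of the ZBV schedule — chunk 0 counts up from the front of the model, chunk 1 counts down from the back. A minimal standalone sketch of that index arithmetic; the helper name and the toy 8-layer split are illustrative, not from the patches:

    import numpy as np

    def zbv_stage_indices(layers_per_stage, stage, num_stages):
        # cumulative layer offsets, e.g. [1] * 8 -> [0, 1, 2, ..., 8]
        acc = np.insert(np.cumsum(layers_per_stage), 0, 0)
        return [
            # chunk 0: walk down the V from the front
            [int(acc[stage]), int(acc[stage + 1])],
            # chunk 1: walk back up from the rear
            [int(acc[2 * num_stages - stage - 1]), int(acc[2 * num_stages - stage])],
        ]

    # 8 layer segments over 4 stages (2 chunks per rank): rank 0 holds layers
    # [0, 1) and [7, 8), and rank 3 holds [3, 4) and [4, 5) — exactly the
    # "layer 0 & 7" ... "layer 3 & 4" pairing the test asserts per rank.
    print(zbv_stage_indices([1] * 8, stage=0, num_stages=4))  # [[0, 1], [7, 8]]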
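
Review note: patch 2 flattens the whole input/output dicts first and only then drops non-tensor leaves, instead of filtering the dicts up front as patch 1 did. That keeps None placeholders and tolerates non-tensor entries such as the stage_index list injected in patch 1. A small illustration of the filter, assuming tree_flatten behaves like torch's (private) torch.utils._pytree helper; the example values are made up:

    import torch
    from torch.utils._pytree import tree_flatten

    output_obj = {
        "hidden_states": torch.randn(2, 4, requires_grad=True),  # tensor leaf: kept
        "stage_index": [0, 1],  # non-tensor leaves: dropped
    }
    leaves, _ = tree_flatten(output_obj)  # leaves here: [tensor, 0, 1]
    output_obj_ = [v for v in leaves if isinstance(v, torch.Tensor) or v is None]
    assert len(output_obj_) == 1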
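
Review note: patches 3 and 6 replace the per-layer MethodType monkey-patching with a single partial-bound forward, so one intact MlpModel copy can stand in for a shardformer-sharded model chunk. Because model_pp.forward is already a bound method, pre-binding only the stage manager is enough; the scheduler then supplies data/hidden_states/stage_index as keyword arguments. The binding, as it appears in the test:

    from functools import partial

    model_pp._forward = model_pp.forward  # keep the original bound method
    model_pp.forward = partial(model_pp._forward, stage_mgr=stage_manager)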