From fafe049b83bad3a6aa6e3a31c68b38ac63167b53 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 29 Oct 2024 03:24:15 +0000 Subject: [PATCH] [fix] fix handle name; rm useless comments; --- .../pipeline/schedule/zero_bubble_pp.py | 7 ++- colossalai/pipeline/weight_grad_store.py | 51 ------------------- colossalai/shardformer/policies/llama.py | 20 +------- .../test_schedule/test_zerobubble_pp.py | 4 -- 4 files changed, 4 insertions(+), 78 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c22dce7da..638b601d4 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): self.local_send_backward_buffer = [] # wait pp buffer - self.send_handles = [] + self.wait_handles = [] def assert_buffer_empty(self): # assert buffer is empty at end @@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): assert len(self.recv_backward_buffer[1]) == 0 assert len(self.local_send_forward_buffer) == 0 assert len(self.local_send_backward_buffer) == 0 - # assert len(self.send_handles) == 0 def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # communication communication_func = self.communication_map[scheduled_node.type] wait_handle = communication_func(scheduled_node.chunk) - self.send_handles.append(wait_handle) + self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, @@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) - for h in self.send_handles: + for h in self.wait_handles: for hh in h: hh.wait() diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index dff4fdd02..c51c45085 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -1,7 +1,5 @@ import queue -from colossalai.pipeline.stage_manager import PipelineStageManager - class WeightGradStore: @@ -32,52 +30,3 @@ class WeightGradStore: weight.grad = grad_weight else: raise Exception("Pop empty queue.") - - @classmethod - def clear(cls, stage_manager: PipelineStageManager, chunk=0): - pass - # print(f"stage {stage_manager.stage} len_chunk_0 {cls.weight_grad_queue[0].qsize()} len_chunk_1 {cls.weight_grad_queue[1].qsize()}") - # while cls.weight_grad_queue[chunk].qsize() > 0: - # stored_grads = cls.weight_grad_queue[chunk].get() - # for total_input, grad_output, weight, func in stored_grads: - # if weight.grad is not None: - # func(total_input, grad_output, weight.grad) - # # for first bwd; weight.grad is None, assign grad_weight to weight.grad - # else: - # grad_weight = func(total_input, grad_output) - # weight.grad = grad_weight - - # weight_grad_tasks = [] - # while cls.weight_grad_queue[chunk].qsize() > 0: - # stored_grads = cls.weight_grad_queue[chunk].get() - # if len(weight_grad_tasks) == 0: - # for _ in stored_grads: - # weight_grad_tasks.append([]) - # else: - # assert len(weight_grad_tasks) == len(stored_grads) - # for i, task in enumerate(stored_grads): - # weight_grad_tasks[i].append(task) - - # if stage_manager.is_last_stage(ignore_chunk=True) and chunk == 1: - # assert len(weight_grad_tasks) > 0 - # output_layer_grads = weight_grad_tasks[0] - # 
for j in range(len(output_layer_grads)): - # total_input, grad_output, weight, func = output_layer_grads[j] - # if output_layer_weight is None: - # output_layer_weight = weight - # assert output_layer_weight is weight - # func(total_input, grad_output, weight.grad) - # output_layer_grads[j] = None # release memory - # weight_grad_tasks = weight_grad_tasks[1:] - - # for i in range(len(weight_grad_tasks)): - # tasks = weight_grad_tasks[i] - # param = None - # for j in range(len(tasks)): - # total_input, grad_output, weight, func = tasks[j] - # if param is None: - # param = weight - # assert param is weight - # func(total_input, grad_output, weight.grad) - # tasks[j] = None # release memory - # weight_grad_tasks[i] = None # release memory diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index bef39a6ca..756d32454 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -60,10 +60,7 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm - if self.pipeline_stage_manager: - use_zbv = self.pipeline_stage_manager.use_zbv - else: - use_zbv = False + use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None @@ -96,7 +93,6 @@ class LlamaPolicy(Policy): target_key=attn_cls, ) - # if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is None: self.append_or_create_method_replacement( description={ @@ -410,20 +406,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy): self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, } ] - # if self.pipeline_stage_manager.use_zbv: - # return [ - # { - # 0: llama_model.embed_tokens.weight, - # 0: self.model.lm_head.weight, - # } - # ] - # else: - # return [ - # { - # 0: llama_model.embed_tokens.weight, - # self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - # } - # ] return [] diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index c485d3f54..71ff11059 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -904,7 +904,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() torch_output.backward() torch_output_sum += torch_output.detach() - # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}") # avg dp grads follows zero optimizer for p in torch_model.parameters(): if p.grad is not None: @@ -912,7 +911,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() - # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}") assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} config {test_config} test passed") clear_layout_converter() @@ -1064,7 +1062,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean() torch_output.backward() torch_output_sum += torch_output.detach() - # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}") # avg dp grads follows zero optimizer for p in torch_model.parameters(): if p.grad is not None: @@ 
-1072,7 +1069,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]): torch_optimizer.step() torch_optimizer.zero_grad() - # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}") assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed") clear_layout_converter()
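
Note: below is a minimal, hypothetical sketch of the handle-draining pattern behind the send_handles -> wait_handles rename in zero_bubble_pp.py: each communication call returns a list of async work handles, the schedule appends every list to wait_handles, and all handles are waited on once the last compute node has been issued. The communicate() helper and the single-process gloo group are stand-ins for illustration only (they are not part of the patch); the real schedule issues P2P send/recv between pipeline stages through self.communication_map.

    # Hypothetical sketch, not part of the patch: collect async work handles
    # per communication node, drain them all at the end of the schedule.
    import os
    import torch
    import torch.distributed as dist


    def communicate(tensors):
        # Stand-in for self.communication_map[node.type](node.chunk):
        # issue async ops and return their work handles without waiting.
        return [dist.all_reduce(t, async_op=True) for t in tensors]


    def main():
        # Single-process gloo group so the sketch runs standalone.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

        wait_handles = []  # list of lists, one inner list per communication node
        for _ in range(4):
            chunk = [torch.ones(8) for _ in range(2)]
            wait_handles.append(communicate(chunk))
            # ... overlapping compute (F/B/W nodes) would run here ...

        # Drain every outstanding handle before reading the communicated tensors,
        # mirroring `for h in self.wait_handles: for hh in h: hh.wait()`.
        for handles in wait_handles:
            for handle in handles:
                handle.wait()

        dist.destroy_process_group()


    if __name__ == "__main__":
        main()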