From fef5c949c35b1f1e0075a9e4abb23a5ec0f48e3c Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Fri, 13 Jan 2023 16:56:01 +0800 Subject: [PATCH] polish pp middleware (#2476) Co-authored-by: Ziyue Jiang --- colossalai/pipeline/rpc/_pipeline_base.py | 4 ++-- colossalai/pipeline/rpc/_pipeline_schedule.py | 3 --- .../gpt/experiments/pipeline_parallel/train_gpt_pp.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 4739cdaa9..1edc1ac70 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -211,7 +211,7 @@ class WorkerBase(ABC): refcount = 0 with self.output_list_condition_lock: - if refcount < lifecycle: + if refcount <= lifecycle: self.output_list[key] = output_work_item self.output_list_condition_lock.notify_all() @@ -390,7 +390,7 @@ class WorkerBase(ABC): subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank) + producer_output_key, rank=self.pp_rank, offsets=offsets) else: for i in range(producer_num): diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index e6aa961f1..0d572231d 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -29,9 +29,6 @@ class FillDrainWorker(WorkerBase): target_key = UniqueKey(target_microbatch_id, target_phase) - with self.work_list_condition_lock: - self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list) - return target_key diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index 79efa61b0..c3451c18d 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -120,7 +120,7 @@ def run_master(args): logger.info(f'{rank=} numel in the partition:{numel}') # build optim - pp_engine.initialize_optimizer(HybridAdam, lr=1e-3) + pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) ranks_tflops = {} for n in range(NUM_STEPS):