polish pp middleware (#2476)

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
pull/2481/head
Ziyue Jiang 2023-01-13 16:56:01 +08:00 committed by GitHub
parent a5dc4253c6
commit fef5c949c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 6 deletions

View File

@ -211,7 +211,7 @@ class WorkerBase(ABC):
refcount = 0
with self.output_list_condition_lock:
if refcount < lifecycle:
if refcount <= lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
@ -390,7 +390,7 @@ class WorkerBase(ABC):
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank)
producer_output_key, rank=self.pp_rank, offsets=offsets)
else:
for i in range(producer_num):

View File

@ -29,9 +29,6 @@ class FillDrainWorker(WorkerBase):
target_key = UniqueKey(target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key

View File

@ -120,7 +120,7 @@ def run_master(args):
logger.info(f'{rank=} numel in the partition:{numel}')
# build optim
pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
ranks_tflops = {}
for n in range(NUM_STEPS):