mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/rpc] update outstanding mechanism | optimize dispatching strategy (#1497)
* support p2p communication with any type of object | pass test
* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [pipeline/rpc] implement a demo for PP with the CUDA RPC framework
* [pipeline/rpc] support interleaving | fix checkpoint bug | change the dispatch logic for data in work_list to ensure steady 1F1B
* [pipeline/rpc] implement distributed optimizer | test with assert_close
* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy
parent 0ed2f46131
commit 5a6fd71f90
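The central change below is that a forward_only flag now travels with every piece of pipeline work: the engine passes it to the first stage, each Worker stores it on the WorkItem it enqueues, and downstream stages receive it through subscribe_producer. As a rough orientation before the diff, here is a minimal single-process sketch of the work-list idea; Key, work_list, and set_input in this sketch are illustrative stand-ins, not the library API.

    from enum import Enum
    from typing import Any, Dict, Tuple

    class Phase(Enum):
        FORWARD = 0
        BACKWARD = 1

    # stands in for UniqueKey: work is addressed by (microbatch_id, phase)
    Key = Tuple[int, Phase]
    work_list: Dict[Key, dict] = {}

    def set_input(microbatch_id: int, microbatch: Any, forward_only: bool) -> None:
        # mirrors the spirit of Worker.set_input: enqueue a FORWARD work item;
        # forward_only is the new field this commit threads through every WorkItem
        work_list[(microbatch_id, Phase.FORWARD)] = {
            "args": [microbatch],
            "forward_only": forward_only,
        }

    set_input(0, "sample-0", forward_only=True)
    print(work_list)    # {(0, <Phase.FORWARD: 0>): {'args': ['sample-0'], 'forward_only': True}}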
@@ -68,7 +68,7 @@ class UniqueKey:

class WorkItem:
    __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
                 'num_microbatches')
                 'num_microbatches', 'forward_only')

    stage_id: int
    phase: Phase
@@ -81,6 +81,7 @@ class WorkItem:
    batch_id: int
    num_microbatches: int
    forward_only: bool

    def __init__(self,
                 stage_id,
@@ -91,6 +92,7 @@ class WorkItem:
                 microbatch_id,
                 batch_id,
                 num_microbatches,
                 forward_only,
                 refcount=0) -> None:
        for attr_name in self.__slots__:
            setattr(self, attr_name, locals()[attr_name])
@@ -129,36 +131,39 @@ class Worker:
                 pp_rank: int,
                 actual_stage_num: int,
                 num_microbatches: int,
                 max_outstanding: int,
                 use_1F1B: bool,
                 device: str,
                 checkpoint: bool = False) -> None:
        super().__init__()
        self.pp_rank = pp_rank
        self.actual_stage_num = actual_stage_num
        self.num_microbatches = num_microbatches
        self.max_outstanding = max_outstanding
        self.outstanding = 0
        self.checkpoint = checkpoint
        self.device = device
        self.outstanding_range = self._initialize_outstanding_range(pp_rank, actual_stage_num, use_1F1B)

        self.future_devices = None if device is None or device == 'cpu' else [device]
        # variables and constants for context management
        self.outstanding = 0
        self.forward_times = 0
        self.backward_times = 0
        self.reset_key = UniqueKey(0, Phase.FORWARD)

        # rref of other workers
        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None

        # topology info
        self.producer_stage_ids: List[int] = None
        self.consumer_stage_ids: List[int] = None

        # module
        # module partitions
        self.module_partition = module_partition.to(device)

        self.debug_list = [None] * num_microbatches

        # containers that drive the work loop
        self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()

        self.work_list: Dict[UniqueKey, WorkItem] = dict()
        self.output_list: Dict[UniqueKey, WorkItem] = dict()

        # Why a Lock instead of an RLock?
        # Because an RLock cannot be pickled.
        # locks for the lists
        self.work_list_condition_lock = threading.Condition(threading.Lock())
        self.output_list_condition_lock = threading.Condition(threading.Lock())
@@ -168,6 +173,15 @@ class Worker:
    def _get_future_by_device(self):
        return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])

    def _initialize_outstanding_range(self, pp_rank: int, actual_stage_num: int, use_1F1B: bool) -> Tuple[int]:
        outstanding_range = None
        if use_1F1B:
            if pp_rank == actual_stage_num - 1:
                outstanding_range = (0, 1)
            else:
                outstanding_range = (actual_stage_num, actual_stage_num)
        return outstanding_range

    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
@@ -197,8 +211,15 @@ class Worker:
    def get_parameter_gradients(self) -> List[torch.Tensor]:
        return [p.grad for p in self.module_partition.parameters()]

    def reset_pp_context(self):
        self.forward_times = 0
        self.backward_times = 0
        self.outstanding = 0
        self.microbatch_id_to_backward_cache.clear()
        self.output_list.clear()

    # just for the first pp_rank
    def set_input(self, microbatch_id: int, microbatch: Tuple[Any]):
    def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
        with self.work_list_condition_lock:
            assert self.consumer_stage_ids is not None
            consumer_num = len(self.consumer_stage_ids)
@@ -207,11 +228,10 @@ class Worker:
            args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch

            work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
                                 self.num_microbatches, consumer_num)
                                 self.num_microbatches, forward_only)
            self.work_list[key] = work_item

            color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')

            self.work_list_condition_lock.notify_all()

    # just for the last pp_rank
@@ -224,24 +244,22 @@ class Worker:
            grad_wrt_loss = torch.tensor(1, device=self.device)

            work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                                 self.num_microbatches, producer_num)
                                 self.num_microbatches, False)

            color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')

            self.work_list[key] = work_item
            self.work_list_condition_lock.notify_all()

    def subscribe_producer(self, microbatch_id: int):
    def subscribe_producer(self, microbatch_id: int, forward_only: bool):
        """
        You should call this function asynchronously
        """
        assert self.producer_stage_ids is not None
        producer_num = len(self.producer_stage_ids)
        consumer_num = len(self.consumer_stage_ids)
        assert producer_num > 0, "only stage that has producers can subscribe producers"

        stage_id = self.pp_rank

        subscribe_forward_futures: List[Future] = [None] * producer_num
        output = self._get_future_by_device()
@@ -259,9 +277,8 @@ class Worker:
            producer_args = subscribe_forward_futures[i].wait()
            args.extend(producer_args)

        # TODO : not only args
        work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
                                           self.num_microbatches, consumer_num)
                                           self.num_microbatches, forward_only)

        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
        # add work_item to work_list
@@ -279,13 +296,10 @@ class Worker:
        You should call this function asynchronously
        """
        assert self.producer_stage_ids is not None
        producer_num = len(self.producer_stage_ids)
        consumer_num = len(self.consumer_stage_ids)
        assert consumer_num > 0, "only stage that has consumers can subscribe consumers"

        # TODO : is this right?
        stage_id = self.pp_rank

        subscribe_backward_futures: List[Future] = [None] * consumer_num
        output = self._get_future_by_device()
@@ -305,7 +319,7 @@ class Worker:

        # flatten args
        work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
                                           self.num_microbatches, producer_num)
                                           self.num_microbatches, False)

        color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
@@ -341,32 +355,57 @@ class Worker:
            while len(self.work_list) == 0:
                self.work_list_condition_lock.wait()

            # execute backward first (if a backward phase is in work_list)
            select_work_list_key = None
            for key in self.work_list:
                work_item = self.work_list[key]
                if work_item.phase == Phase.FORWARD and \
                        self.max_outstanding is not None and \
                        self.outstanding >= self.max_outstanding:
                    continue
                else:
                    if select_work_list_key is not None and \
                            select_work_list_key.phase == Phase.FORWARD and \
                            key.phase == Phase.BACKWARD:
                        continue
            # each stage must do Key(microbatch_id=0, phase=FORWARD) first
            # before doing the operation, reset the context first
            if self.reset_key in self.work_list:
                self.reset_pp_context()

                    if select_work_list_key is None:
                        select_work_list_key = key
                    else:
                        phase_pair = (select_work_list_key.phase, key.phase)
                        # choose forward first
                        if phase_pair == (Phase.BACKWARD, Phase.FORWARD):
                            select_work_list_key = key
                        elif phase_pair == (Phase.FORWARD, Phase.BACKWARD):
                            continue
                        # choose the work_item with the smaller microbatch_id first
                        elif key.microbatch_id < select_work_list_key.microbatch_id:
                            select_work_list_key = key
            # execute backward first (if a backward phase is in work_list)
            pp_rank = self.pp_rank
            actual_stage_num = self.actual_stage_num
            num_microbatches = self.num_microbatches
            is_last_stage = pp_rank == actual_stage_num - 1
            select_work_list_key: UniqueKey = None

            if self.outstanding_range:
                if self.outstanding <= self.outstanding_range[0]:
                    target_phase = Phase.FORWARD
                    target_microbatch_id = self.forward_times
                elif self.outstanding >= self.outstanding_range[1]:
                    target_phase = Phase.BACKWARD
                    target_microbatch_id = self.backward_times
                else:
                    raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

                target_key = UniqueKey(target_microbatch_id, target_phase)
                if target_key in self.work_list:
                    select_work_list_key = target_key

                # change outstanding_range at:
                # 1. forward times reach actual_stage_num, this is the end of continuous forward
                # 2. forward times reach num_microbatches, this is the end of 1F1B mode
                if not is_last_stage and \
                        select_work_list_key is not None and \
                        select_work_list_key.phase == Phase.FORWARD:
                    if select_work_list_key.microbatch_id == actual_stage_num - 1:
                        outstanding_min = actual_stage_num - pp_rank - 1
                        outstanding_max = actual_stage_num - pp_rank
                        self.outstanding_range = (outstanding_min, outstanding_max)
                    elif select_work_list_key.microbatch_id == num_microbatches - 1:
                        self.outstanding_range = (0, 0)

            else:
                if self.forward_times < num_microbatches:
                    target_phase = Phase.FORWARD
                    target_microbatch_id = self.forward_times
                else:
                    target_phase = Phase.BACKWARD
                    target_microbatch_id = self.backward_times

                target_key = UniqueKey(target_microbatch_id, target_phase)

                if target_key in self.work_list:
                    select_work_list_key = target_key

            return select_work_list_key
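For illustration only, the snippet below replays the selection rule from this hunk for a hypothetical 4-stage pipeline and a worker at pp_rank 1; it reproduces just the outstanding_range arithmetic, not the worker's real state machine.

    # hypothetical setup: 4 stages, worker at pp_rank 1 (values chosen for illustration)
    actual_stage_num, pp_rank = 4, 1

    # non-last stages start with an unconstrained warm-up range
    outstanding_range = (actual_stage_num, actual_stage_num)    # (4, 4): keep issuing forwards

    def pick_phase(outstanding, outstanding_range):
        # mirrors the rule above: at/below the low bound do a FORWARD,
        # at/above the high bound do a BACKWARD
        low, high = outstanding_range
        if outstanding <= low:
            return "FORWARD"
        if outstanding >= high:
            return "BACKWARD"
        raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

    # once the forward for microbatch actual_stage_num - 1 has been selected,
    # the range tightens so roughly (actual_stage_num - pp_rank) forwards stay in flight
    outstanding_range = (actual_stage_num - pp_rank - 1, actual_stage_num - pp_rank)    # (2, 3)

    print(pick_phase(2, outstanding_range))    # FORWARD  -> keep the pipeline full
    print(pick_phase(3, outstanding_range))    # BACKWARD -> drain one before the next forward

    # after the forward for the last microbatch, the range drops to (0, 0): backward only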
@@ -375,15 +414,28 @@ class Worker:
        args = work_item.args
        kwargs = work_item.kwargs
        microbatch_id = work_item.microbatch_id
        forward_only = work_item.forward_only
        consume_result = None

        # if self.pp_rank == 0:
        #     print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list))
        # TODO : use process manager to acquire rank info later
        is_first_stage = (self.pp_rank == 0)
        is_last_stage = (self.pp_rank == self.actual_stage_num - 1)

        # color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue')
        # if self.pp_rank == 3:
        #     print(
        #         f'I am rank_{self.pp_rank} microbatch_id : {microbatch_id} {phase} {self._get_store_len()} | {self.outstanding} {self.outstanding_range}'
        #     )

        if phase == Phase.FORWARD:
            self.outstanding += 1
            # remind its consumer to get data before forward
            if not is_last_stage:
                for stage_id in self.consumer_stage_ids:
                    consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
                    consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
            self.forward_times += 1

            if not forward_only:
                self.outstanding += 1

            # TODO : more elegant ?
            for i in range(len(args)):
@@ -391,35 +443,46 @@ class Worker:
                if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad:
                    args[i] = arg_obj.requires_grad_()

            # TODO : use process manager to acquire rank info later
            is_last_stage = (self.pp_rank == self.actual_stage_num - 1)

            # the last stage doesn't need to checkpoint, because it runs backward immediately
            if self.checkpoint and not is_last_stage:
            if forward_only:
                with torch.no_grad():
                    consume_result = self.module_partition(*args, **kwargs)
                stage_outputs = None
                stage_inputs = None
                use_checkpoint = None
            elif self.checkpoint and not is_last_stage:
                with torch.no_grad():
                    consume_result = self.module_partition(*args, **kwargs)
                stage_outputs = None
                stage_inputs = args
                self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
                                                                                    stage_outputs,
                                                                                    checkpoint=True)
                use_checkpoint = True
            else:
                consume_result = self.module_partition(*args, **kwargs)

                stage_outputs = consume_result
                stage_inputs = args
                use_checkpoint = False

            if not forward_only:
                self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
                                                                                    stage_outputs,
                                                                                    checkpoint=False)
                                                                                    checkpoint=use_checkpoint)

            consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result

            # if it is the last stage, trigger backward automatically
            if is_last_stage:
                self._begin_backward(microbatch_id)
            # if not forward_only, do the backward
            if not forward_only:
                if is_last_stage:    # if it is the last stage, trigger backward automatically
                    self._begin_backward(microbatch_id)

        elif phase == Phase.BACKWARD:
            # remind its producer to get data before backward
            if not is_first_stage:
                for stage_id in self.producer_stage_ids:
                    producer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
                    producer_worker_rref.remote().subscribe_consumer(microbatch_id)
            self.backward_times += 1
            self.outstanding -= 1

            assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
            backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
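To summarize the three forward paths above, here is an illustrative decision table distilled from those branches (not part of the patch): a forward_only pass runs under no_grad and caches nothing, checkpointing on a non-last stage caches only the inputs for recomputation, and the plain training path keeps inputs and outputs for backward.

    # illustrative only: condenses the forward-path branches above into one function
    def forward_mode(forward_only: bool, checkpoint: bool, is_last_stage: bool) -> str:
        if forward_only:
            # inference pass: no autograd graph, nothing stored in the backward cache
            return "no_grad, no backward cache"
        if checkpoint and not is_last_stage:
            # activation checkpointing: run under no_grad, keep only the inputs,
            # recompute the stage outputs later during backward
            return "no_grad, cache inputs only (recompute in backward)"
        # regular training path (and always the last stage): keep inputs and outputs
        return "grad enabled, cache inputs and outputs"

    for combo in [(True, False, False), (False, True, False), (False, False, True)]:
        print(combo, "->", forward_mode(*combo))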
@@ -445,6 +508,9 @@ class Worker:

        return consume_result

    def _get_store_len(self):
        return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)}'

    # do the main loop to consume ready_list
    def _work_loop(self):
        # for init
@@ -461,7 +527,7 @@ class Worker:
                work_item = self.work_list.pop(work_item_key)

            color_debug(
                f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
                f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
                'work loop', 'green')

            with self.output_list_condition_lock:
@@ -472,7 +538,7 @@ class Worker:
            consume_result = self._consume_work_item_by_phase(work_item)

            color_debug(
                f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
                f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()}',
                'work loop', 'green')
            work_item.output.set_result(consume_result)
@@ -489,9 +555,6 @@ class Worker:
            self.optimizer.zero_grad()


# TODO
# 1. chunk
# 2. checkpoint
class PipelineEngineBase(ABC, nn.Module):

    def __init__(self,
@@ -499,19 +562,18 @@ class PipelineEngineBase(ABC, nn.Module):
                 stage_num,
                 num_microbatches,
                 device: str,
                 max_outstanding=None,
                 use_1F1B=False,
                 chunk: int = 1,
                 use_interleave: bool = False,
                 checkpoint: bool = False) -> None:
        super().__init__()
        self.module_partitions: List[nn.Module] = module_partitions
        self.chunk = chunk
        self.num_microbatches = num_microbatches
        self.device = device
        self.max_outstanding = max_outstanding
        self.use_1F1B = use_1F1B
        self.stage_num = stage_num
        self.checkpoint = checkpoint
        self.use_interleave = use_interleave
        self.use_interleave = chunk > 1

        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@@ -547,7 +609,7 @@ class PipelineEngineBase(ABC, nn.Module):
    def _init_worker(self):
        actual_stage_num = self._get_actual_stage_num()

        max_outstanding = self.max_outstanding
        use_1F1B = self.use_1F1B
        checkpoint = self.checkpoint
        num_microbatches = self.num_microbatches
        device = self.device
@@ -560,8 +622,7 @@ class PipelineEngineBase(ABC, nn.Module):
            self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
                                                              Worker,
                                                              args=(module_partition, pp_rank, actual_stage_num,
                                                                    num_microbatches, max_outstanding, device,
                                                                    checkpoint))
                                                                    num_microbatches, use_1F1B, device, checkpoint))

        # let each worker know the global worker rrefs (including itself)
        for pp_rank in range(actual_stage_num):
@@ -585,46 +646,55 @@ class PipelineEngineBase(ABC, nn.Module):
                grads[stage_id].append(grad)
        return grads

    def forward_backward(self, batch: torch.Tensor):
        first_stage_worker = self.pp_rank_to_worker_rref[0]
        microbatch_size = len(batch) // self.num_microbatches
    def forward_backward(self, batch: torch.Tensor, forward_only: bool = False):
        num_microbatches = self.num_microbatches
        microbatch_size = len(batch) // num_microbatches
        actual_stage_num = self._get_actual_stage_num()

        microbatch_iter = range(self.num_microbatches)
        first_stage_worker = self.pp_rank_to_worker_rref[0]
        last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]

        microbatch_iter = range(num_microbatches)
        if use_progress:
            microbatch_iter = tqdm(microbatch_iter)

        ret_future: List[Future] = [None] * num_microbatches
        from time import sleep

        for microbatch_id in microbatch_iter:
            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]

            # forward subscribe asynchronously
            for pp_rank in range(1, actual_stage_num, 1):
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                worker_rref.rpc_async().subscribe_producer(microbatch_id)

            # backward subscribe asynchronously
            for pp_rank in range(actual_stage_num - 2, -1, -1):
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                worker_rref.rpc_async().subscribe_consumer(microbatch_id)
            # control the data input speed
            # to avoid exceeding the wait limit
            if microbatch_id >= actual_stage_num:
                if forward_only or not self.use_1F1B:
                    ret_future[microbatch_id - actual_stage_num].wait()
                else:
                    key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
                    first_stage_worker.rpc_sync().get_output_by_key(key)

            # run one microbatch
            first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch)
            first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch, forward_only)

            key = UniqueKey(microbatch_id, Phase.FORWARD)
            ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)

        # wait forward
        # TODO : all the node to output
        forward_result = None
        last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]

        for microbatch_id in range(self.num_microbatches):
            key = UniqueKey(microbatch_id, Phase.FORWARD)
            ret = last_worker_rref.rpc_sync().get_output_by_key(key)
            ret = ret_future[microbatch_id].wait()
            if forward_result is None:
                forward_result = [[]] * len(ret)
            for i in range(len(forward_result)):
                forward_result[i].append(ret[i])

        # wait for the last backward on rank 0
        key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
        first_stage_worker.rpc_sync().get_output_by_key(key)
        if not forward_only:
            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
            first_stage_worker.rpc_sync().get_output_by_key(key)
        return forward_result
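As a hedged usage sketch of the engine-level API touched here (based only on the signatures in this diff): the partition modules, stage count, and optimizer settings below are placeholder values, and the test scripts further down show the same calls in their real RPC setup, which this sketch omits.

    import torch
    from torch import nn
    from torch.optim import SGD

    from colossalai.pipeline.rpc.PipelineBase import OneFOneBPipelineEngine

    # placeholder partitions for a 2-stage pipeline (illustrative only)
    module_partitions = [nn.Linear(16, 16), nn.Linear(16, 4)]

    # note: in the tests this runs inside an RPC worker group set up by rpc_run
    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                    stage_num=2,
                                    num_microbatches=4,
                                    device='cuda',
                                    chunk=1,
                                    checkpoint=False)
    engine.initialize_optimizer(SGD, lr=1e-3)

    batch = torch.randn(64, 16)

    # training step: forward plus backward per microbatch, 1F1B dispatch on each worker
    forward_result = engine.forward_backward(batch, forward_only=False)

    # inference step: workers run under no_grad and skip the backward cache entirely
    logits = engine.forward_backward(batch, forward_only=True)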
    def initialize_optimizer(self, optimizer_class: type, **kwargs):
@@ -654,11 +724,9 @@ class FillDrainPipelineEngine(PipelineEngineBase):
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 use_interleave: bool = False,
                 checkpoint: bool = False) -> None:
        max_outstanding = None
        super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
                         checkpoint)
        use_1F1B = False
        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)


class OneFOneBPipelineEngine(PipelineEngineBase):
@@ -668,11 +736,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 max_outstanding=None,
                 chunk: int = 1,
                 use_interleave: bool = False,
                 checkpoint: bool = False) -> None:
        if max_outstanding is None:
            max_outstanding = len(module_partitions)
        super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
                         checkpoint)
        use_1F1B = True
        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, checkpoint)
@@ -5,13 +5,9 @@ import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch import autograd
from torch.optim import SGD, Adam, RMSprop, Optimizer
from colorama import Back, Style

from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close


def color_debug(text, prefix=' ', color='blue'):
    color = color.upper()
@@ -43,13 +39,13 @@ class RpcTestModel(nn.Module):

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--num_microbatches', type=int, default=2)
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.add_argument('--use_interleave', action='store_true')
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--master_addr', type=str, default='localhost')
    parser.add_argument('--master_port', type=str, default='29020')
    parser.add_argument('--num_worker_threads', type=str, default=128)
@@ -1,13 +1,7 @@
import os
import argparse

import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch import autograd
from torch.optim import SGD, Adam, RMSprop, Optimizer
from colorama import Back, Style

from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
@@ -21,7 +15,6 @@ def run_master(args):
    stage_num = args.world_size
    chunk = args.chunk
    actual_stage_num = stage_num * chunk
    use_interleave = args.use_interleave
    use_checkpoint = args.use_checkpoint
    num_microbatches = args.num_microbatches
    optimizer_class = globals()[args.optimizer]
@@ -45,7 +38,6 @@ def run_master(args):
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    use_interleave=use_interleave,
                                    checkpoint=use_checkpoint)

    engine.initialize_optimizer(optimizer_class, lr=lr)
@@ -1,10 +1,5 @@
import os
import argparse

import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
@@ -13,12 +8,12 @@ from rpc_test_utils import rpc_run, parse_args, RpcTestModel
def run_master(args):
    torch.manual_seed(100)

    epoch = args.epoch
    device = args.device
    stage_num = args.world_size
    chunk = args.chunk
    num_microbatches = args.num_microbatches
    actual_stage_num = stage_num * chunk
    use_interleave = args.use_interleave
    use_checkpoint = args.use_checkpoint

    sample_num = 1024
@@ -38,10 +33,10 @@ def run_master(args):
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    use_interleave=use_interleave,
                                    checkpoint=use_checkpoint)

    _ = engine.forward_backward(input_sample)
    for _ in range(epoch):
        _ = engine.forward_backward(input_sample, forward_only=False)


if __name__ == "__main__":
@@ -1,12 +1,6 @@
import os
import argparse

import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch import autograd
from colorama import Back, Style

from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
@@ -20,7 +14,6 @@ def run_master(args):
    stage_num = args.world_size
    chunk = args.chunk
    actual_stage_num = stage_num * chunk
    use_interleave = args.use_interleave
    use_checkpoint = args.use_checkpoint
    num_microbatches = args.num_microbatches
@@ -41,7 +34,6 @@ def run_master(args):
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    use_interleave=use_interleave,
                                    checkpoint=use_checkpoint)

    forward_result = engine.forward_backward(input_sample)