diff --git a/colossalai/pipeline/rpc/PipelineBase.py b/colossalai/pipeline/rpc/PipelineBase.py index 36a69499e..6c3d0afe5 100644 --- a/colossalai/pipeline/rpc/PipelineBase.py +++ b/colossalai/pipeline/rpc/PipelineBase.py @@ -1,7 +1,7 @@ import threading from enum import Enum from typing import List, Any, Tuple, Dict -from abc import ABC, abstractmethod +from abc import ABC import torch from torch import nn @@ -18,9 +18,8 @@ use_color_debug = False use_progress = False # TODO: -# 1. design a unique_key without node.name (Maybe I can use combination of microbatch_id and stage_id) -# 2. use waiting list to contain the uncomplete WorkItem -# 3. think about the representation of the order of args and kwargs +# 1. replace world_size with other parameters +# 2. adjust to args and kwargs def color_debug(text, prefix=' ', color='blue'): @@ -126,33 +125,32 @@ class RemoteOptimizer: class Worker: def __init__(self, - cur_rank_module: nn.Module, - rank: int, - world_size: int, + module_partition: nn.Module, + pp_rank: int, + actual_stage_num: int, num_microbatches: int, max_outstanding: int, device: str, checkpoint: bool = False) -> None: super().__init__() - self.rank = rank - self.world_size = world_size + self.pp_rank = pp_rank + self.actual_stage_num = actual_stage_num self.num_microbatches = num_microbatches self.max_outstanding = max_outstanding self.outstanding = 0 self.checkpoint = checkpoint - - if device == 'cuda': - device = f'cuda:{rank}' self.device = device self.future_devices = None if device is None or device == 'cpu' else [device] - self.stage_to_worker_rref: Dict[int, PyRRef] = None + self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None self.producer_stage_ids: List[int] = None self.consumer_stage_ids: List[int] = None # module - self.cur_rank_module = cur_rank_module.to(device) + self.module_partition = module_partition.to(device) + + self.debug_list = [None] * num_microbatches self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict() @@ -164,16 +162,16 @@ class Worker: self.work_list_condition_lock = threading.Condition(threading.Lock()) self.output_list_condition_lock = threading.Condition(threading.Lock()) - self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{rank}', daemon=True) + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True) self.main_loop_thread.start() def _get_future_by_device(self): return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) - def sync_global_worker_rrefs(self, stage_to_worker_rref: Dict[int, PyRRef]) -> None: - assert self.stage_to_worker_rref is None, f"in rank {self.rank}, worker has sync global workers rrefs" - assert stage_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" - self.stage_to_worker_rref = stage_to_worker_rref + def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None: + assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs" + assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" + self.pp_rank_to_worker_rref = pp_rank_to_worker_rref def get_output_by_key(self, key: UniqueKey) -> Any: with self.output_list_condition_lock: @@ -183,7 +181,7 @@ class Worker: output_work_item = self.output_list[key] output = output_work_item.output.wait() - # color_debug(f'rank {self.rank}, output {type(output)}', 'get output', 'red') + # color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red') output_work_item.refcount += 1 # all consumers have been satisfied, the work_item can be released @@ -193,8 +191,13 @@ class Worker: return output - # just for first rank - # TODO : input is args kwargs + def get_parameters(self) -> List[torch.Tensor]: + return [p for p in self.module_partition.parameters()] + + def get_parameter_gradients(self) -> List[torch.Tensor]: + return [p.grad for p in self.module_partition.parameters()] + + # just for first pp_rank def set_input(self, microbatch_id: int, microbatch: Tuple[Any]): with self.work_list_condition_lock: assert self.consumer_stage_ids is not None @@ -203,16 +206,15 @@ class Worker: output = self._get_future_by_device() args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch - work_item = WorkItem(self.rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches, - consumer_num) + work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, + self.num_microbatches, consumer_num) self.work_list[key] = work_item - color_debug(f'rank {self.rank} receive data from dataloader', 'data dispatch', 'magenta') + color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() - # just for last rank - # TODO : write a function to add gradient to work_list and see if there is contradictory + # just for last pp_rank def _begin_backward(self, microbatch_id: int): with self.work_list_condition_lock: assert self.producer_stage_ids is not None @@ -221,10 +223,10 @@ class Worker: output = self._get_future_by_device() grad_wrt_loss = torch.tensor(1, device=self.device) - work_item = WorkItem(self.rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, + work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, self.num_microbatches, producer_num) - color_debug(f'rank {self.rank} propose backward', 'data dispatch', 'magenta') + color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta') self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -238,7 +240,7 @@ class Worker: consumer_num = len(self.consumer_stage_ids) assert producer_num > 0, "only stage that has producers can subscribe producers" - stage_id = self.rank + stage_id = self.pp_rank subscribe_forward_futures: List[Future] = [None] * producer_num output = self._get_future_by_device() @@ -246,10 +248,10 @@ class Worker: for i in range(producer_num): producer_stage_id = self.producer_stage_ids[i] producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) - producer_worker_rref = self.stage_to_worker_rref[producer_stage_id] + producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) - color_debug(f'rank {self.rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch', + color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch', 'magenta') args = [] @@ -261,14 +263,14 @@ class Worker: work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches, consumer_num) - color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') + color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') # add work_item to work_list with self.work_list_condition_lock: key = UniqueKey(microbatch_id, Phase.FORWARD) assert key not in self.work_list self.work_list[key] = work_item_from_producer color_debug( - f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}', + f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}', 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() @@ -282,18 +284,18 @@ class Worker: assert consumer_num > 0, "only stage that has consumers can subscribe comsumers" # TODO : is this right? - stage_id = self.rank + stage_id = self.pp_rank subscribe_backward_futures: List[Future] = [None] * consumer_num output = self._get_future_by_device() - color_debug(f'rank {self.rank} get {len(subscribe_backward_futures)} futs from its consumer', 'data dispatch', - 'magenta') + color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer', + 'data dispatch', 'magenta') for i in range(consumer_num): consumer_stage_id = self.consumer_stage_ids[i] consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) - consumer_worker_rref = self.stage_to_worker_rref[consumer_stage_id] + consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key) args = [] @@ -305,7 +307,7 @@ class Worker: work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None, self.num_microbatches, producer_num) - color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') + color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') # add work_item to work_list with self.work_list_condition_lock: @@ -313,13 +315,12 @@ class Worker: assert key not in self.work_list self.work_list[key] = work_item_from_consumer color_debug( - f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}', + f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}', 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() - # TODO : fit in any type of partition of network def _get_producer_consumer(self) -> None: - rank = self.rank + rank = self.pp_rank assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" @@ -332,34 +333,41 @@ class Worker: next_rank = rank + 1 if prev_rank >= 0: self.producer_stage_ids.append(prev_rank) - if next_rank <= self.world_size - 1: + if next_rank <= self.actual_stage_num - 1: self.consumer_stage_ids.append(next_rank) - def _skip_forward(self, work_item_phase: Phase) -> bool: - if work_item_phase == Phase.FORWARD and \ - self.max_outstanding is not None and \ - self.outstanding >= self.max_outstanding: - return True - return False - def _get_work_item_key(self) -> UniqueKey: with self.work_list_condition_lock: while len(self.work_list) == 0: self.work_list_condition_lock.wait() # execute backward first (if backward phase in work_list) - select_work_list_key = None for key in self.work_list: work_item = self.work_list[key] - - if work_item.phase == Phase.BACKWARD: - return key - - if self._skip_forward(work_item.phase): + if work_item.phase == Phase.FORWARD and \ + self.max_outstanding is not None and \ + self.outstanding >= self.max_outstanding: continue else: - select_work_list_key = key + if select_work_list_key is not None and \ + select_work_list_key.phase == Phase.FORWARD and \ + key.phase == Phase.BACKWARD: + continue + + if select_work_list_key is None: + select_work_list_key = key + else: + phase_pair = (select_work_list_key.phase, key.phase) + # choose forward first + if phase_pair == (Phase.BACKWARD, Phase.FORWARD): + select_work_list_key = key + elif phase_pair == (Phase.FORWARD, Phase.BACKWARD): + continue + # choose work_item which has a smaller microbactch_id first + elif key.microbatch_id < select_work_list_key.microbatch_id: + select_work_list_key = key + return select_work_list_key def _consume_work_item_by_phase(self, work_item: WorkItem): @@ -369,7 +377,10 @@ class Worker: microbatch_id = work_item.microbatch_id consume_result = None - # color_debug(f'rank_{self.rank} enter consume', 'consume', 'blue') + # if self.pp_rank == 0: + # print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list)) + + # color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue') if phase == Phase.FORWARD: self.outstanding += 1 @@ -381,19 +392,20 @@ class Worker: args[i] = arg_obj.requires_grad_() # TODO : use process manager to acquire rank info later - is_last_stage = len(self.consumer_stage_ids) == 0 + is_last_stage = (self.pp_rank == self.actual_stage_num - 1) + # last stage doesn't need to do checkpoint, for it will do backward instantly if self.checkpoint and not is_last_stage: with torch.no_grad(): - consume_result = self.cur_rank_module(*args, **kwargs) + consume_result = self.module_partition(*args, **kwargs) stage_outputs = None stage_inputs = args self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, stage_outputs, checkpoint=True) else: - # TODO : replace with *args, **kwargs and ensure the consume_result is a tuple - consume_result = self.cur_rank_module(*args, **kwargs) + consume_result = self.module_partition(*args, **kwargs) + stage_outputs = consume_result stage_inputs = args self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, @@ -415,15 +427,13 @@ class Worker: stage_inputs = backward_cache.stage_inputs grad_tensors = args - # color_debug(f'rank_{self.rank} before backward', 'consume', 'yellow') + use_checkpoint = backward_cache.checkpoint - if self.checkpoint: - stage_outputs = [self.cur_rank_module(*stage_inputs)] + if use_checkpoint: + stage_outputs = [self.module_partition(*stage_inputs)] autograd.backward(stage_outputs, grad_tensors=grad_tensors) - # color_debug(f'rank_{self.rank} after backward', 'consume', 'yellow') - # collect grad of input tensor consume_result = [] for input_node in stage_inputs: @@ -453,7 +463,7 @@ class Worker: work_item = self.work_list.pop(work_item_key) color_debug( - f'rank {self.rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}', + f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}', 'work loop', 'green') with self.output_list_condition_lock: @@ -464,11 +474,8 @@ class Worker: consume_result = self._consume_work_item_by_phase(work_item) color_debug( - f'rank_{self.rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}', + f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}', 'work loop', 'green') - # if work_item.stage_id == 1 and work_item.phase == Phase.BACKWARD: - # from time import sleep - # sleep(5) work_item.output.set_result(consume_result) @@ -479,11 +486,11 @@ class PipelineEngineBase(ABC, nn.Module): def __init__(self, module_partitions, - chunk, - world_size, + stage_num, num_microbatches, device: str, max_outstanding=None, + chunk: int = 1, use_interleave: bool = False, checkpoint: bool = False) -> None: super().__init__() @@ -492,55 +499,86 @@ class PipelineEngineBase(ABC, nn.Module): self.num_microbatches = num_microbatches self.device = device self.max_outstanding = max_outstanding - self.world_size = world_size + self.stage_num = stage_num self.checkpoint = checkpoint self.use_interleave = use_interleave - self.stage_to_worker_rref: Dict[int, PyRRef] = dict() + self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() + + self._check_argument() + self._create_pp_rank_to_rpc_worker_id() self._init_worker() + def _check_argument(self): + self.virtual_stage_num = self.stage_num * self.chunk + + assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!" + assert self.virtual_stage_num == len( + self.module_partitions), "stage_num * chunk must be equal to length of model partition!" + if self.use_interleave: + assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + + def _get_actual_stage_num(self): + return self.stage_num if self.chunk == 1 else self.virtual_stage_num + + def _create_pp_rank_to_rpc_worker_id(self): + """create a map from model partition to stage_id, which is useful when use_interleave is True. + e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3. + stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part + of partitions will be moved to device 0 and the others to device 1 + + """ + stage_num = self.stage_num + actual_stage_num = self._get_actual_stage_num() + self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num + for pp_rank in range(actual_stage_num): + self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num + def _init_worker(self): - world_size = self.world_size + actual_stage_num = self._get_actual_stage_num() + max_outstanding = self.max_outstanding checkpoint = self.checkpoint num_microbatches = self.num_microbatches device = self.device - # TODO : world size is correct ? - for rank in range(world_size): - cur_rank_module = self.module_partitions[rank] - self.stage_to_worker_rref[rank] = rpc.remote(rank, - Worker, - args=(cur_rank_module, rank, world_size, num_microbatches, - max_outstanding, device, checkpoint)) + for pp_rank in range(actual_stage_num): + module_partition = self.module_partitions[pp_rank] + rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] + if device[:4] == 'cuda': + device = f'cuda:{rpc_worker_id}' + self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, + Worker, + args=(module_partition, pp_rank, actual_stage_num, + num_microbatches, max_outstanding, device, + checkpoint)) # let each worker know global worker rref (include itself) - for rank in range(world_size): - self.stage_to_worker_rref[rank].rpc_sync().sync_global_worker_rrefs(self.stage_to_worker_rref) - - @abstractmethod - def forward_backward(self): - pass - + for pp_rank in range(actual_stage_num): + self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + + def remote_parameters(self) -> Dict[int, List[torch.Tensor]]: + parameters = {} + for stage_id in self.pp_rank_to_worker_rref: + parameters[stage_id] = [] + worker_rref = self.pp_rank_to_worker_rref[stage_id] + for p in worker_rref.rpc_sync().get_parameters(): + parameters[stage_id].append(p) + return parameters + + def remote_grad(self) -> Dict[int, List[torch.Tensor]]: + grads = {} + for stage_id in self.pp_rank_to_worker_rref: + grads[stage_id] = [] + worker_rref = self.pp_rank_to_worker_rref[stage_id] + for grad in worker_rref.rpc_sync().get_parameter_gradients(): + grads[stage_id].append(grad) + return grads -class FillDrainPipelineEngine(PipelineEngineBase): - - def __init__(self, - module_partitions, - chunk, - world_size, - num_microbatches, - device: str, - max_outstanding=None, - use_interleave: bool = False, - checkpoint: bool = False) -> None: - super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding, - use_interleave, checkpoint) - - # TODO : adjust to args and kwargs def forward_backward(self, batch: torch.Tensor): - first_stage_worker = self.stage_to_worker_rref[0] + first_stage_worker = self.pp_rank_to_worker_rref[0] microbatch_size = len(batch) // self.num_microbatches + actual_stage_num = self._get_actual_stage_num() microbatch_iter = range(self.num_microbatches) if use_progress: @@ -550,31 +588,63 @@ class FillDrainPipelineEngine(PipelineEngineBase): microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] # forward subscribe asynchronously - for rank in range(1, self.world_size, 1): - worker_rref = self.stage_to_worker_rref[rank] + for pp_rank in range(1, actual_stage_num, 1): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] worker_rref.rpc_async().subscribe_producer(microbatch_id) # backward subscribe asynchronously - for rank in range(self.world_size - 2, -1, -1): - worker_rref = self.stage_to_worker_rref[rank] + for pp_rank in range(actual_stage_num - 2, -1, -1): + worker_rref = self.pp_rank_to_worker_rref[pp_rank] worker_rref.rpc_async().subscribe_consumer(microbatch_id) # run one microbatch first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch) + # wait forward + # TODO : all the node to output + forward_result = None + last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1] + for microbatch_id in range(self.num_microbatches): + key = UniqueKey(microbatch_id, Phase.FORWARD) + ret = last_worker_rref.rpc_sync().get_output_by_key(key) + if forward_result is None: + forward_result = [[]] * len(ret) + for i in range(len(forward_result)): + forward_result[i].append(ret[i]) + + # wait for last backward in rank0 + key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) + first_stage_worker.rpc_sync().get_output_by_key(key) + return forward_result + -class OneFOneBPipelineEngine(FillDrainPipelineEngine): +class FillDrainPipelineEngine(PipelineEngineBase): def __init__(self, - module_partitions, - chunk, - world_size, - num_microbatches, + module_partitions: List[nn.Module], + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + use_interleave: bool = False, + checkpoint: bool = False) -> None: + max_outstanding = None + super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave, + checkpoint) + + +class OneFOneBPipelineEngine(PipelineEngineBase): + + def __init__(self, + module_partitions: List[nn.Module], + stage_num: int, + num_microbatches: int, device: str, max_outstanding=None, + chunk: int = 1, use_interleave: bool = False, checkpoint: bool = False) -> None: if max_outstanding is None: - max_outstanding = world_size - super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding, - use_interleave, checkpoint) + max_outstanding = len(module_partitions) + super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave, + checkpoint) diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py index 66b76d136..6608a5c5a 100644 --- a/tests/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -11,14 +11,14 @@ from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOn class TestModel(nn.Module): - def __init__(self, rank, world_size, feat_num, h) -> None: + def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: super().__init__() - self.rank = rank - self.is_last_rank = rank == world_size - 1 - self.linear_name = f'linear_{rank}' - if rank == 0: + self.rank = stage_id + self.is_last_rank = stage_id == actual_stage_num - 1 + self.linear_name = f'linear_{stage_id}' + if stage_id == 0: setattr(self, self.linear_name, nn.Linear(feat_num, h)) - elif rank == world_size - 1: + elif stage_id == actual_stage_num - 1: setattr(self, self.linear_name, nn.Linear(h, 1)) else: setattr(self, self.linear_name, nn.Linear(h, h)) @@ -35,32 +35,35 @@ class TestModel(nn.Module): def run_main(args): torch.manual_seed(100) - sample_num = 128 - feat_num = 10000 - h = 10000 device = args.device - world_size = args.world_size - batch_size = 128 + stage_num = args.world_size + chunk = args.chunk + num_microbatches = args.num_microbatches + actual_stage_num = stage_num * chunk + use_interleave = args.use_interleave + use_checkpoint = args.use_checkpoint + + sample_num = 1024 + feat_num = 10 + h = 10 + batch_size = 1024 + assert sample_num % batch_size == 0 batch_num = sample_num // batch_size - num_microbatches = world_size input_sample = torch.randn((sample_num, feat_num), device=device) - module_partitions = [TestModel(rank, world_size, feat_num, h) for rank in range(world_size)] + module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] engine = OneFOneBPipelineEngine(module_partitions=module_partitions, - chunk=1, - world_size=world_size, + stage_num=stage_num, num_microbatches=num_microbatches, - device=args.device, - max_outstanding=world_size, - use_interleave=False, - checkpoint=False) + device=device, + chunk=chunk, + use_interleave=use_interleave, + checkpoint=use_checkpoint) - for i in range(batch_num): - batch = input_sample[i * batch_size:(i + 1) * batch_size] - engine.forward_backward(batch) + _ = engine.forward_backward(input_sample) def run_worker(rank, args): @@ -88,7 +91,11 @@ def run_worker(rank, args): def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--num_microbatches', type=int, default=2) parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--use_interleave', action='store_true') parser.add_argument('--master_addr', type=str, default='localhost') parser.add_argument('--master_port', type=str, default='29020') parser.add_argument('--num_worker_threads', type=str, default=128) diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py new file mode 100644 index 000000000..0c5f75a12 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py @@ -0,0 +1,150 @@ +import os +import argparse + +import torch +from torch import nn +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc +from torch import autograd +from colorama import Back, Style + +from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine + + +def color_debug(text, prefix=' ', color='blue'): + color = color.upper() + print(getattr(Back, color), prefix, Style.RESET_ALL, text) + + +class TestModel(nn.Module): + + def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: + super().__init__() + self.rank = stage_id + self.is_last_rank = stage_id == actual_stage_num - 1 + self.linear_name = f'linear_{stage_id}' + if stage_id == 0: + setattr(self, self.linear_name, nn.Linear(feat_num, h)) + elif stage_id == actual_stage_num - 1: + setattr(self, self.linear_name, nn.Linear(h, 1)) + else: + setattr(self, self.linear_name, nn.Linear(h, h)) + + def forward(self, x) -> torch.Tensor: + linear: nn.Module = getattr(self, self.linear_name) + out: torch.Tensor = linear(x) + + if self.is_last_rank: + out = out.sum() + return out + + +def run_main(args): + torch.manual_seed(100) + + device = args.device + stage_num = args.world_size + chunk = args.chunk + actual_stage_num = stage_num * chunk + use_interleave = args.use_interleave + use_checkpoint = args.use_checkpoint + + sample_num = 1024 + feat_num = 100 + h = 100 + batch_size = 1024 + + assert sample_num % batch_size == 0 + batch_num = sample_num // batch_size + + num_microbatches = stage_num * 1 + + input_sample = torch.randn((sample_num, feat_num), device=device) + + module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)] + + engine = OneFOneBPipelineEngine(module_partitions=module_partitions, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + use_interleave=use_interleave, + checkpoint=use_checkpoint) + + forward_result = engine.forward_backward(input_sample) + + cuda_rpc_result = [] + single_result = [] + actual_stage_num = engine._get_actual_stage_num() + + # color_debug('cuda rpc forward', 'Test') + # print(sum(forward_result[0])) + cuda_rpc_result.append(sum(forward_result[0]).item()) + # color_debug('cuda rpc backward', 'Test') + grad = engine.remote_grad() + for stage_id in range(actual_stage_num): + for p in grad[stage_id]: + # print(p.sum()) + cuda_rpc_result.append(p.sum().item()) + + test_model = nn.Sequential(*module_partitions).to(device) + input_sample = input_sample.requires_grad_() + out_val = test_model(input_sample).sum() + autograd.backward(out_val) + # color_debug('single forward', 'Test') + # print(out_val) + single_result.append(out_val.item()) + # color_debug('single backward', 'Test') + for p in test_model.parameters(): + # print(p.grad.sum()) + single_result.append(p.grad.sum().item()) + + cuda_rpc_result = torch.tensor(cuda_rpc_result) + single_result = torch.tensor(single_result) + distance = (cuda_rpc_result - single_result).abs().sum().item() + kappa = round(distance / actual_stage_num, 5) + assert kappa < 0.01, f"kappa({kappa}) is too big, PP result may be incorrect!" + + +def run_worker(rank, args): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + # config rpc + # if cuda is used, set_device_map is a must is configured + # for cuda is not supported in torch rpc by default + options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads) + + world_size = args.world_size + for rank_idx in range(world_size): + options.set_device_map(f'work{rank_idx}', {rank: rank_idx}) + + rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + run_main(args) + # barrier here + rpc.shutdown() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--num_microbatches', type=int, default=2) + parser.add_argument('--chunk', type=int, default=1) + parser.add_argument('--use_checkpoint', action='store_true') + parser.add_argument('--use_interleave', action='store_true') + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=str, default=128) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + world_size = args.world_size + assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!" + assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!" + mp.spawn(run_worker, args=(args,), nprocs=world_size)