mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B (#1483)
* support p2p communication with any type of object | pass test * reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test * [engin/schedule] use p2p_v2 to recontruct pipeline_schedule * [pipeline/rpc] implement a demo for PP with cuda rpc framework * [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1Bpull/1486/head
parent
d6e3dca436
commit
a6c8749198
|
@ -1,7 +1,7 @@
|
|||
import threading
|
||||
from enum import Enum
|
||||
from typing import List, Any, Tuple, Dict
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -18,9 +18,8 @@ use_color_debug = False
|
|||
use_progress = False
|
||||
|
||||
# TODO:
|
||||
# 1. design a unique_key without node.name (Maybe I can use combination of microbatch_id and stage_id)
|
||||
# 2. use waiting list to contain the uncomplete WorkItem
|
||||
# 3. think about the representation of the order of args and kwargs
|
||||
# 1. replace world_size with other parameters
|
||||
# 2. adjust to args and kwargs
|
||||
|
||||
|
||||
def color_debug(text, prefix=' ', color='blue'):
|
||||
|
@ -126,33 +125,32 @@ class RemoteOptimizer:
|
|||
class Worker:
|
||||
|
||||
def __init__(self,
|
||||
cur_rank_module: nn.Module,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
module_partition: nn.Module,
|
||||
pp_rank: int,
|
||||
actual_stage_num: int,
|
||||
num_microbatches: int,
|
||||
max_outstanding: int,
|
||||
device: str,
|
||||
checkpoint: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.pp_rank = pp_rank
|
||||
self.actual_stage_num = actual_stage_num
|
||||
self.num_microbatches = num_microbatches
|
||||
self.max_outstanding = max_outstanding
|
||||
self.outstanding = 0
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
if device == 'cuda':
|
||||
device = f'cuda:{rank}'
|
||||
self.device = device
|
||||
|
||||
self.future_devices = None if device is None or device == 'cpu' else [device]
|
||||
|
||||
self.stage_to_worker_rref: Dict[int, PyRRef] = None
|
||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
|
||||
self.producer_stage_ids: List[int] = None
|
||||
self.consumer_stage_ids: List[int] = None
|
||||
|
||||
# module
|
||||
self.cur_rank_module = cur_rank_module.to(device)
|
||||
self.module_partition = module_partition.to(device)
|
||||
|
||||
self.debug_list = [None] * num_microbatches
|
||||
|
||||
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
|
||||
|
||||
|
@ -164,16 +162,16 @@ class Worker:
|
|||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{rank}', daemon=True)
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
||||
self.main_loop_thread.start()
|
||||
|
||||
def _get_future_by_device(self):
|
||||
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
|
||||
|
||||
def sync_global_worker_rrefs(self, stage_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||
assert self.stage_to_worker_rref is None, f"in rank {self.rank}, worker has sync global workers rrefs"
|
||||
assert stage_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
||||
self.stage_to_worker_rref = stage_to_worker_rref
|
||||
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
||||
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
||||
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
|
||||
|
||||
def get_output_by_key(self, key: UniqueKey) -> Any:
|
||||
with self.output_list_condition_lock:
|
||||
|
@ -183,7 +181,7 @@ class Worker:
|
|||
output_work_item = self.output_list[key]
|
||||
|
||||
output = output_work_item.output.wait()
|
||||
# color_debug(f'rank {self.rank}, output {type(output)}', 'get output', 'red')
|
||||
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
|
||||
output_work_item.refcount += 1
|
||||
|
||||
# all consumers have been satisfied, the work_item can be released
|
||||
|
@ -193,8 +191,13 @@ class Worker:
|
|||
|
||||
return output
|
||||
|
||||
# just for first rank
|
||||
# TODO : input is args kwargs
|
||||
def get_parameters(self) -> List[torch.Tensor]:
|
||||
return [p for p in self.module_partition.parameters()]
|
||||
|
||||
def get_parameter_gradients(self) -> List[torch.Tensor]:
|
||||
return [p.grad for p in self.module_partition.parameters()]
|
||||
|
||||
# just for first pp_rank
|
||||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any]):
|
||||
with self.work_list_condition_lock:
|
||||
assert self.consumer_stage_ids is not None
|
||||
|
@ -203,16 +206,15 @@ class Worker:
|
|||
output = self._get_future_by_device()
|
||||
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
|
||||
|
||||
work_item = WorkItem(self.rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
|
||||
consumer_num)
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, consumer_num)
|
||||
self.work_list[key] = work_item
|
||||
|
||||
color_debug(f'rank {self.rank} receive data from dataloader', 'data dispatch', 'magenta')
|
||||
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
|
||||
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# just for last rank
|
||||
# TODO : write a function to add gradient to work_list and see if there is contradictory
|
||||
# just for last pp_rank
|
||||
def _begin_backward(self, microbatch_id: int):
|
||||
with self.work_list_condition_lock:
|
||||
assert self.producer_stage_ids is not None
|
||||
|
@ -221,10 +223,10 @@ class Worker:
|
|||
output = self._get_future_by_device()
|
||||
grad_wrt_loss = torch.tensor(1, device=self.device)
|
||||
|
||||
work_item = WorkItem(self.rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
|
||||
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, producer_num)
|
||||
|
||||
color_debug(f'rank {self.rank} propose backward', 'data dispatch', 'magenta')
|
||||
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
|
||||
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
@ -238,7 +240,7 @@ class Worker:
|
|||
consumer_num = len(self.consumer_stage_ids)
|
||||
assert producer_num > 0, "only stage that has producers can subscribe producers"
|
||||
|
||||
stage_id = self.rank
|
||||
stage_id = self.pp_rank
|
||||
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
output = self._get_future_by_device()
|
||||
|
@ -246,10 +248,10 @@ class Worker:
|
|||
for i in range(producer_num):
|
||||
producer_stage_id = self.producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.stage_to_worker_rref[producer_stage_id]
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
|
||||
color_debug(f'rank {self.rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
|
||||
'magenta')
|
||||
|
||||
args = []
|
||||
|
@ -261,14 +263,14 @@ class Worker:
|
|||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, consumer_num)
|
||||
|
||||
color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||
# add work_item to work_list
|
||||
with self.work_list_condition_lock:
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_producer
|
||||
color_debug(
|
||||
f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
|
@ -282,18 +284,18 @@ class Worker:
|
|||
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
|
||||
|
||||
# TODO : is this right?
|
||||
stage_id = self.rank
|
||||
stage_id = self.pp_rank
|
||||
|
||||
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
||||
output = self._get_future_by_device()
|
||||
|
||||
color_debug(f'rank {self.rank} get {len(subscribe_backward_futures)} futs from its consumer', 'data dispatch',
|
||||
'magenta')
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
|
||||
'data dispatch', 'magenta')
|
||||
|
||||
for i in range(consumer_num):
|
||||
consumer_stage_id = self.consumer_stage_ids[i]
|
||||
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
consumer_worker_rref = self.stage_to_worker_rref[consumer_stage_id]
|
||||
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
|
||||
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
|
||||
|
||||
args = []
|
||||
|
@ -305,7 +307,7 @@ class Worker:
|
|||
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, producer_num)
|
||||
|
||||
color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
|
||||
|
||||
# add work_item to work_list
|
||||
with self.work_list_condition_lock:
|
||||
|
@ -313,13 +315,12 @@ class Worker:
|
|||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_consumer
|
||||
color_debug(
|
||||
f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# TODO : fit in any type of partition of network
|
||||
def _get_producer_consumer(self) -> None:
|
||||
rank = self.rank
|
||||
rank = self.pp_rank
|
||||
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
|
||||
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
|
||||
|
||||
|
@ -332,34 +333,41 @@ class Worker:
|
|||
next_rank = rank + 1
|
||||
if prev_rank >= 0:
|
||||
self.producer_stage_ids.append(prev_rank)
|
||||
if next_rank <= self.world_size - 1:
|
||||
if next_rank <= self.actual_stage_num - 1:
|
||||
self.consumer_stage_ids.append(next_rank)
|
||||
|
||||
def _skip_forward(self, work_item_phase: Phase) -> bool:
|
||||
if work_item_phase == Phase.FORWARD and \
|
||||
self.max_outstanding is not None and \
|
||||
self.outstanding >= self.max_outstanding:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
with self.work_list_condition_lock:
|
||||
while len(self.work_list) == 0:
|
||||
self.work_list_condition_lock.wait()
|
||||
|
||||
# execute backward first (if backward phase in work_list)
|
||||
|
||||
select_work_list_key = None
|
||||
for key in self.work_list:
|
||||
work_item = self.work_list[key]
|
||||
|
||||
if work_item.phase == Phase.BACKWARD:
|
||||
return key
|
||||
|
||||
if self._skip_forward(work_item.phase):
|
||||
if work_item.phase == Phase.FORWARD and \
|
||||
self.max_outstanding is not None and \
|
||||
self.outstanding >= self.max_outstanding:
|
||||
continue
|
||||
else:
|
||||
select_work_list_key = key
|
||||
if select_work_list_key is not None and \
|
||||
select_work_list_key.phase == Phase.FORWARD and \
|
||||
key.phase == Phase.BACKWARD:
|
||||
continue
|
||||
|
||||
if select_work_list_key is None:
|
||||
select_work_list_key = key
|
||||
else:
|
||||
phase_pair = (select_work_list_key.phase, key.phase)
|
||||
# choose forward first
|
||||
if phase_pair == (Phase.BACKWARD, Phase.FORWARD):
|
||||
select_work_list_key = key
|
||||
elif phase_pair == (Phase.FORWARD, Phase.BACKWARD):
|
||||
continue
|
||||
# choose work_item which has a smaller microbactch_id first
|
||||
elif key.microbatch_id < select_work_list_key.microbatch_id:
|
||||
select_work_list_key = key
|
||||
|
||||
return select_work_list_key
|
||||
|
||||
def _consume_work_item_by_phase(self, work_item: WorkItem):
|
||||
|
@ -369,7 +377,10 @@ class Worker:
|
|||
microbatch_id = work_item.microbatch_id
|
||||
consume_result = None
|
||||
|
||||
# color_debug(f'rank_{self.rank} enter consume', 'consume', 'blue')
|
||||
# if self.pp_rank == 0:
|
||||
# print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list))
|
||||
|
||||
# color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue')
|
||||
|
||||
if phase == Phase.FORWARD:
|
||||
self.outstanding += 1
|
||||
|
@ -381,19 +392,20 @@ class Worker:
|
|||
args[i] = arg_obj.requires_grad_()
|
||||
|
||||
# TODO : use process manager to acquire rank info later
|
||||
is_last_stage = len(self.consumer_stage_ids) == 0
|
||||
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
|
||||
|
||||
# last stage doesn't need to do checkpoint, for it will do backward instantly
|
||||
if self.checkpoint and not is_last_stage:
|
||||
with torch.no_grad():
|
||||
consume_result = self.cur_rank_module(*args, **kwargs)
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
stage_outputs = None
|
||||
stage_inputs = args
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
|
||||
stage_outputs,
|
||||
checkpoint=True)
|
||||
else:
|
||||
# TODO : replace with *args, **kwargs and ensure the consume_result is a tuple
|
||||
consume_result = self.cur_rank_module(*args, **kwargs)
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
|
||||
stage_outputs = consume_result
|
||||
stage_inputs = args
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
|
||||
|
@ -415,15 +427,13 @@ class Worker:
|
|||
stage_inputs = backward_cache.stage_inputs
|
||||
grad_tensors = args
|
||||
|
||||
# color_debug(f'rank_{self.rank} before backward', 'consume', 'yellow')
|
||||
use_checkpoint = backward_cache.checkpoint
|
||||
|
||||
if self.checkpoint:
|
||||
stage_outputs = [self.cur_rank_module(*stage_inputs)]
|
||||
if use_checkpoint:
|
||||
stage_outputs = [self.module_partition(*stage_inputs)]
|
||||
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
|
||||
# color_debug(f'rank_{self.rank} after backward', 'consume', 'yellow')
|
||||
|
||||
# collect grad of input tensor
|
||||
consume_result = []
|
||||
for input_node in stage_inputs:
|
||||
|
@ -453,7 +463,7 @@ class Worker:
|
|||
work_item = self.work_list.pop(work_item_key)
|
||||
|
||||
color_debug(
|
||||
f'rank {self.rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
|
||||
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
|
||||
'work loop', 'green')
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
|
@ -464,11 +474,8 @@ class Worker:
|
|||
consume_result = self._consume_work_item_by_phase(work_item)
|
||||
|
||||
color_debug(
|
||||
f'rank_{self.rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
|
||||
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
|
||||
'work loop', 'green')
|
||||
# if work_item.stage_id == 1 and work_item.phase == Phase.BACKWARD:
|
||||
# from time import sleep
|
||||
# sleep(5)
|
||||
work_item.output.set_result(consume_result)
|
||||
|
||||
|
||||
|
@ -479,11 +486,11 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
def __init__(self,
|
||||
module_partitions,
|
||||
chunk,
|
||||
world_size,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
max_outstanding=None,
|
||||
chunk: int = 1,
|
||||
use_interleave: bool = False,
|
||||
checkpoint: bool = False) -> None:
|
||||
super().__init__()
|
||||
|
@ -492,55 +499,86 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
self.num_microbatches = num_microbatches
|
||||
self.device = device
|
||||
self.max_outstanding = max_outstanding
|
||||
self.world_size = world_size
|
||||
self.stage_num = stage_num
|
||||
self.checkpoint = checkpoint
|
||||
self.use_interleave = use_interleave
|
||||
|
||||
self.stage_to_worker_rref: Dict[int, PyRRef] = dict()
|
||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
|
||||
|
||||
self._check_argument()
|
||||
self._create_pp_rank_to_rpc_worker_id()
|
||||
self._init_worker()
|
||||
|
||||
def _check_argument(self):
|
||||
self.virtual_stage_num = self.stage_num * self.chunk
|
||||
|
||||
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
|
||||
assert self.virtual_stage_num == len(
|
||||
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
|
||||
if self.use_interleave:
|
||||
assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
|
||||
def _get_actual_stage_num(self):
|
||||
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
||||
|
||||
def _create_pp_rank_to_rpc_worker_id(self):
|
||||
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
||||
e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
|
||||
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
||||
of partitions will be moved to device 0 and the others to device 1
|
||||
|
||||
"""
|
||||
stage_num = self.stage_num
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num
|
||||
for pp_rank in range(actual_stage_num):
|
||||
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
|
||||
|
||||
def _init_worker(self):
|
||||
world_size = self.world_size
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
|
||||
max_outstanding = self.max_outstanding
|
||||
checkpoint = self.checkpoint
|
||||
num_microbatches = self.num_microbatches
|
||||
device = self.device
|
||||
|
||||
# TODO : world size is correct ?
|
||||
for rank in range(world_size):
|
||||
cur_rank_module = self.module_partitions[rank]
|
||||
self.stage_to_worker_rref[rank] = rpc.remote(rank,
|
||||
Worker,
|
||||
args=(cur_rank_module, rank, world_size, num_microbatches,
|
||||
max_outstanding, device, checkpoint))
|
||||
for pp_rank in range(actual_stage_num):
|
||||
module_partition = self.module_partitions[pp_rank]
|
||||
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
|
||||
if device[:4] == 'cuda':
|
||||
device = f'cuda:{rpc_worker_id}'
|
||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
||||
Worker,
|
||||
args=(module_partition, pp_rank, actual_stage_num,
|
||||
num_microbatches, max_outstanding, device,
|
||||
checkpoint))
|
||||
|
||||
# let each worker know global worker rref (include itself)
|
||||
for rank in range(world_size):
|
||||
self.stage_to_worker_rref[rank].rpc_sync().sync_global_worker_rrefs(self.stage_to_worker_rref)
|
||||
for pp_rank in range(actual_stage_num):
|
||||
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
||||
|
||||
@abstractmethod
|
||||
def forward_backward(self):
|
||||
pass
|
||||
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
|
||||
parameters = {}
|
||||
for stage_id in self.pp_rank_to_worker_rref:
|
||||
parameters[stage_id] = []
|
||||
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
||||
for p in worker_rref.rpc_sync().get_parameters():
|
||||
parameters[stage_id].append(p)
|
||||
return parameters
|
||||
|
||||
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
|
||||
grads = {}
|
||||
for stage_id in self.pp_rank_to_worker_rref:
|
||||
grads[stage_id] = []
|
||||
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
||||
for grad in worker_rref.rpc_sync().get_parameter_gradients():
|
||||
grads[stage_id].append(grad)
|
||||
return grads
|
||||
|
||||
class FillDrainPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions,
|
||||
chunk,
|
||||
world_size,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
max_outstanding=None,
|
||||
use_interleave: bool = False,
|
||||
checkpoint: bool = False) -> None:
|
||||
super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
|
||||
use_interleave, checkpoint)
|
||||
|
||||
# TODO : adjust to args and kwargs
|
||||
def forward_backward(self, batch: torch.Tensor):
|
||||
first_stage_worker = self.stage_to_worker_rref[0]
|
||||
first_stage_worker = self.pp_rank_to_worker_rref[0]
|
||||
microbatch_size = len(batch) // self.num_microbatches
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
|
||||
microbatch_iter = range(self.num_microbatches)
|
||||
if use_progress:
|
||||
|
@ -550,31 +588,63 @@ class FillDrainPipelineEngine(PipelineEngineBase):
|
|||
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||
|
||||
# forward subscribe asynchronously
|
||||
for rank in range(1, self.world_size, 1):
|
||||
worker_rref = self.stage_to_worker_rref[rank]
|
||||
for pp_rank in range(1, actual_stage_num, 1):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_async().subscribe_producer(microbatch_id)
|
||||
|
||||
# backward subscribe asynchronously
|
||||
for rank in range(self.world_size - 2, -1, -1):
|
||||
worker_rref = self.stage_to_worker_rref[rank]
|
||||
for pp_rank in range(actual_stage_num - 2, -1, -1):
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_async().subscribe_consumer(microbatch_id)
|
||||
|
||||
# run one microbatch
|
||||
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch)
|
||||
|
||||
# wait forward
|
||||
# TODO : all the node to output
|
||||
forward_result = None
|
||||
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
|
||||
for microbatch_id in range(self.num_microbatches):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
ret = last_worker_rref.rpc_sync().get_output_by_key(key)
|
||||
if forward_result is None:
|
||||
forward_result = [[]] * len(ret)
|
||||
for i in range(len(forward_result)):
|
||||
forward_result[i].append(ret[i])
|
||||
|
||||
class OneFOneBPipelineEngine(FillDrainPipelineEngine):
|
||||
# wait for last backward in rank0
|
||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||
first_stage_worker.rpc_sync().get_output_by_key(key)
|
||||
return forward_result
|
||||
|
||||
|
||||
class FillDrainPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions,
|
||||
chunk,
|
||||
world_size,
|
||||
num_microbatches,
|
||||
module_partitions: List[nn.Module],
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
use_interleave: bool = False,
|
||||
checkpoint: bool = False) -> None:
|
||||
max_outstanding = None
|
||||
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
|
||||
checkpoint)
|
||||
|
||||
|
||||
class OneFOneBPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions: List[nn.Module],
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
max_outstanding=None,
|
||||
chunk: int = 1,
|
||||
use_interleave: bool = False,
|
||||
checkpoint: bool = False) -> None:
|
||||
if max_outstanding is None:
|
||||
max_outstanding = world_size
|
||||
super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
|
||||
use_interleave, checkpoint)
|
||||
max_outstanding = len(module_partitions)
|
||||
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
|
||||
checkpoint)
|
||||
|
|
|
@ -11,14 +11,14 @@ from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOn
|
|||
|
||||
class TestModel(nn.Module):
|
||||
|
||||
def __init__(self, rank, world_size, feat_num, h) -> None:
|
||||
def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
|
||||
super().__init__()
|
||||
self.rank = rank
|
||||
self.is_last_rank = rank == world_size - 1
|
||||
self.linear_name = f'linear_{rank}'
|
||||
if rank == 0:
|
||||
self.rank = stage_id
|
||||
self.is_last_rank = stage_id == actual_stage_num - 1
|
||||
self.linear_name = f'linear_{stage_id}'
|
||||
if stage_id == 0:
|
||||
setattr(self, self.linear_name, nn.Linear(feat_num, h))
|
||||
elif rank == world_size - 1:
|
||||
elif stage_id == actual_stage_num - 1:
|
||||
setattr(self, self.linear_name, nn.Linear(h, 1))
|
||||
else:
|
||||
setattr(self, self.linear_name, nn.Linear(h, h))
|
||||
|
@ -35,32 +35,35 @@ class TestModel(nn.Module):
|
|||
def run_main(args):
|
||||
torch.manual_seed(100)
|
||||
|
||||
sample_num = 128
|
||||
feat_num = 10000
|
||||
h = 10000
|
||||
device = args.device
|
||||
world_size = args.world_size
|
||||
batch_size = 128
|
||||
stage_num = args.world_size
|
||||
chunk = args.chunk
|
||||
num_microbatches = args.num_microbatches
|
||||
actual_stage_num = stage_num * chunk
|
||||
use_interleave = args.use_interleave
|
||||
use_checkpoint = args.use_checkpoint
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 10
|
||||
h = 10
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
batch_num = sample_num // batch_size
|
||||
num_microbatches = world_size
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
module_partitions = [TestModel(rank, world_size, feat_num, h) for rank in range(world_size)]
|
||||
module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
|
||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
chunk=1,
|
||||
world_size=world_size,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=args.device,
|
||||
max_outstanding=world_size,
|
||||
use_interleave=False,
|
||||
checkpoint=False)
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
use_interleave=use_interleave,
|
||||
checkpoint=use_checkpoint)
|
||||
|
||||
for i in range(batch_num):
|
||||
batch = input_sample[i * batch_size:(i + 1) * batch_size]
|
||||
engine.forward_backward(batch)
|
||||
_ = engine.forward_backward(input_sample)
|
||||
|
||||
|
||||
def run_worker(rank, args):
|
||||
|
@ -88,7 +91,11 @@ def run_worker(rank, args):
|
|||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--world_size', type=int, default=2)
|
||||
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||
parser.add_argument('--device', type=str, default='cuda')
|
||||
parser.add_argument('--chunk', type=int, default=1)
|
||||
parser.add_argument('--use_checkpoint', action='store_true')
|
||||
parser.add_argument('--use_interleave', action='store_true')
|
||||
parser.add_argument('--master_addr', type=str, default='localhost')
|
||||
parser.add_argument('--master_port', type=str, default='29020')
|
||||
parser.add_argument('--num_worker_threads', type=str, default=128)
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch import autograd
|
||||
from colorama import Back, Style
|
||||
|
||||
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
|
||||
|
||||
def color_debug(text, prefix=' ', color='blue'):
|
||||
color = color.upper()
|
||||
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
|
||||
|
||||
|
||||
class TestModel(nn.Module):
|
||||
|
||||
def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
|
||||
super().__init__()
|
||||
self.rank = stage_id
|
||||
self.is_last_rank = stage_id == actual_stage_num - 1
|
||||
self.linear_name = f'linear_{stage_id}'
|
||||
if stage_id == 0:
|
||||
setattr(self, self.linear_name, nn.Linear(feat_num, h))
|
||||
elif stage_id == actual_stage_num - 1:
|
||||
setattr(self, self.linear_name, nn.Linear(h, 1))
|
||||
else:
|
||||
setattr(self, self.linear_name, nn.Linear(h, h))
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
linear: nn.Module = getattr(self, self.linear_name)
|
||||
out: torch.Tensor = linear(x)
|
||||
|
||||
if self.is_last_rank:
|
||||
out = out.sum()
|
||||
return out
|
||||
|
||||
|
||||
def run_main(args):
|
||||
torch.manual_seed(100)
|
||||
|
||||
device = args.device
|
||||
stage_num = args.world_size
|
||||
chunk = args.chunk
|
||||
actual_stage_num = stage_num * chunk
|
||||
use_interleave = args.use_interleave
|
||||
use_checkpoint = args.use_checkpoint
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 100
|
||||
h = 100
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
batch_num = sample_num // batch_size
|
||||
|
||||
num_microbatches = stage_num * 1
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
|
||||
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
use_interleave=use_interleave,
|
||||
checkpoint=use_checkpoint)
|
||||
|
||||
forward_result = engine.forward_backward(input_sample)
|
||||
|
||||
cuda_rpc_result = []
|
||||
single_result = []
|
||||
actual_stage_num = engine._get_actual_stage_num()
|
||||
|
||||
# color_debug('cuda rpc forward', 'Test')
|
||||
# print(sum(forward_result[0]))
|
||||
cuda_rpc_result.append(sum(forward_result[0]).item())
|
||||
# color_debug('cuda rpc backward', 'Test')
|
||||
grad = engine.remote_grad()
|
||||
for stage_id in range(actual_stage_num):
|
||||
for p in grad[stage_id]:
|
||||
# print(p.sum())
|
||||
cuda_rpc_result.append(p.sum().item())
|
||||
|
||||
test_model = nn.Sequential(*module_partitions).to(device)
|
||||
input_sample = input_sample.requires_grad_()
|
||||
out_val = test_model(input_sample).sum()
|
||||
autograd.backward(out_val)
|
||||
# color_debug('single forward', 'Test')
|
||||
# print(out_val)
|
||||
single_result.append(out_val.item())
|
||||
# color_debug('single backward', 'Test')
|
||||
for p in test_model.parameters():
|
||||
# print(p.grad.sum())
|
||||
single_result.append(p.grad.sum().item())
|
||||
|
||||
cuda_rpc_result = torch.tensor(cuda_rpc_result)
|
||||
single_result = torch.tensor(single_result)
|
||||
distance = (cuda_rpc_result - single_result).abs().sum().item()
|
||||
kappa = round(distance / actual_stage_num, 5)
|
||||
assert kappa < 0.01, f"kappa({kappa}) is too big, PP result may be incorrect!"
|
||||
|
||||
|
||||
def run_worker(rank, args):
|
||||
os.environ['MASTER_ADDR'] = args.master_addr
|
||||
os.environ['MASTER_PORT'] = args.master_port
|
||||
|
||||
# config rpc
|
||||
# if cuda is used, set_device_map is a must is configured
|
||||
# for cuda is not supported in torch rpc by default
|
||||
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
|
||||
|
||||
world_size = args.world_size
|
||||
for rank_idx in range(world_size):
|
||||
options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
|
||||
|
||||
rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
|
||||
|
||||
# in rpc mode, only rank 0 is needed to be coded
|
||||
if rank == 0:
|
||||
run_main(args)
|
||||
# barrier here
|
||||
rpc.shutdown()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--world_size', type=int, default=2)
|
||||
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||
parser.add_argument('--chunk', type=int, default=1)
|
||||
parser.add_argument('--use_checkpoint', action='store_true')
|
||||
parser.add_argument('--use_interleave', action='store_true')
|
||||
parser.add_argument('--device', type=str, default='cuda')
|
||||
parser.add_argument('--master_addr', type=str, default='localhost')
|
||||
parser.add_argument('--master_port', type=str, default='29020')
|
||||
parser.add_argument('--num_worker_threads', type=str, default=128)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
world_size = args.world_size
|
||||
assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!"
|
||||
assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!"
|
||||
mp.spawn(run_worker, args=(args,), nprocs=world_size)
|
Loading…
Reference in New Issue