[pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B (#1483)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test

* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with cuda rpc framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B
pull/1486/head
Kirigaya Kazuto 2022-08-24 11:19:46 +08:00 committed by GitHub
parent d6e3dca436
commit a6c8749198
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 366 additions and 139 deletions

View File

@ -1,7 +1,7 @@
import threading
from enum import Enum
from typing import List, Any, Tuple, Dict
from abc import ABC, abstractmethod
from abc import ABC
import torch
from torch import nn
@ -18,9 +18,8 @@ use_color_debug = False
use_progress = False
# TODO:
# 1. design a unique_key without node.name (Maybe I can use combination of microbatch_id and stage_id)
# 2. use waiting list to contain the uncomplete WorkItem
# 3. think about the representation of the order of args and kwargs
# 1. replace world_size with other parameters
# 2. adjust to args and kwargs
def color_debug(text, prefix=' ', color='blue'):
@ -126,33 +125,32 @@ class RemoteOptimizer:
class Worker:
def __init__(self,
cur_rank_module: nn.Module,
rank: int,
world_size: int,
module_partition: nn.Module,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
max_outstanding: int,
device: str,
checkpoint: bool = False) -> None:
super().__init__()
self.rank = rank
self.world_size = world_size
self.pp_rank = pp_rank
self.actual_stage_num = actual_stage_num
self.num_microbatches = num_microbatches
self.max_outstanding = max_outstanding
self.outstanding = 0
self.checkpoint = checkpoint
if device == 'cuda':
device = f'cuda:{rank}'
self.device = device
self.future_devices = None if device is None or device == 'cpu' else [device]
self.stage_to_worker_rref: Dict[int, PyRRef] = None
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = None
self.producer_stage_ids: List[int] = None
self.consumer_stage_ids: List[int] = None
# module
self.cur_rank_module = cur_rank_module.to(device)
self.module_partition = module_partition.to(device)
self.debug_list = [None] * num_microbatches
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
@ -164,16 +162,16 @@ class Worker:
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{rank}', daemon=True)
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
self.main_loop_thread.start()
def _get_future_by_device(self):
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
def sync_global_worker_rrefs(self, stage_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.stage_to_worker_rref is None, f"in rank {self.rank}, worker has sync global workers rrefs"
assert stage_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
self.stage_to_worker_rref = stage_to_worker_rref
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
def get_output_by_key(self, key: UniqueKey) -> Any:
with self.output_list_condition_lock:
@ -183,7 +181,7 @@ class Worker:
output_work_item = self.output_list[key]
output = output_work_item.output.wait()
# color_debug(f'rank {self.rank}, output {type(output)}', 'get output', 'red')
# color_debug(f'rank {self.pp_rank}, output {type(output)}', 'get output', 'red')
output_work_item.refcount += 1
# all consumers have been satisfied, the work_item can be released
@ -193,8 +191,13 @@ class Worker:
return output
# just for first rank
# TODO : input is args kwargs
def get_parameters(self) -> List[torch.Tensor]:
    """Return every parameter of ``self.module_partition`` as a list."""
    return list(self.module_partition.parameters())
def get_parameter_gradients(self) -> List[torch.Tensor]:
    """Return the ``.grad`` of every parameter of ``self.module_partition``.

    The list follows the order of ``module_partition.parameters()``; an entry
    is ``None`` for a parameter whose ``.grad`` has not been populated.
    """
    gradients = []
    for param in self.module_partition.parameters():
        gradients.append(param.grad)
    return gradients
# just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any]):
with self.work_list_condition_lock:
assert self.consumer_stage_ids is not None
@ -203,16 +206,15 @@ class Worker:
output = self._get_future_by_device()
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
work_item = WorkItem(self.rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
consumer_num)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num)
self.work_list[key] = work_item
color_debug(f'rank {self.rank} receive data from dataloader', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# just for last rank
# TODO : write a function to add gradient to work_list and see if there is contradictory
# just for last pp_rank
def _begin_backward(self, microbatch_id: int):
with self.work_list_condition_lock:
assert self.producer_stage_ids is not None
@ -221,10 +223,10 @@ class Worker:
output = self._get_future_by_device()
grad_wrt_loss = torch.tensor(1, device=self.device)
work_item = WorkItem(self.rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, producer_num)
color_debug(f'rank {self.rank} propose backward', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@ -238,7 +240,7 @@ class Worker:
consumer_num = len(self.consumer_stage_ids)
assert producer_num > 0, "only stage that has producers can subscribe producers"
stage_id = self.rank
stage_id = self.pp_rank
subscribe_forward_futures: List[Future] = [None] * producer_num
output = self._get_future_by_device()
@ -246,10 +248,10 @@ class Worker:
for i in range(producer_num):
producer_stage_id = self.producer_stage_ids[i]
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
producer_worker_rref = self.stage_to_worker_rref[producer_stage_id]
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
color_debug(f'rank {self.rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
'magenta')
args = []
@ -261,14 +263,14 @@ class Worker:
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, consumer_num)
color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
color_debug(
f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
@ -282,18 +284,18 @@ class Worker:
assert consumer_num > 0, "only stage that has consumers can subscribe comsumers"
# TODO : is this right?
stage_id = self.rank
stage_id = self.pp_rank
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
color_debug(f'rank {self.rank} get {len(subscribe_backward_futures)} futs from its consumer', 'data dispatch',
'magenta')
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
'data dispatch', 'magenta')
for i in range(consumer_num):
consumer_stage_id = self.consumer_stage_ids[i]
consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD)
consumer_worker_rref = self.stage_to_worker_rref[consumer_stage_id]
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key)
args = []
@ -305,7 +307,7 @@ class Worker:
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None,
self.num_microbatches, producer_num)
color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta')
# add work_item to work_list
with self.work_list_condition_lock:
@ -313,13 +315,12 @@ class Worker:
assert key not in self.work_list
self.work_list[key] = work_item_from_consumer
color_debug(
f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# TODO : fit in any type of partition of network
def _get_producer_consumer(self) -> None:
rank = self.rank
rank = self.pp_rank
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
@ -332,34 +333,41 @@ class Worker:
next_rank = rank + 1
if prev_rank >= 0:
self.producer_stage_ids.append(prev_rank)
if next_rank <= self.world_size - 1:
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
def _skip_forward(self, work_item_phase: Phase) -> bool:
    """Tell whether a work item of the given phase must be held back.

    Only FORWARD items can be skipped, and only when ``self.max_outstanding``
    is set and ``self.outstanding`` has already reached it.
    """
    if work_item_phase != Phase.FORWARD:
        return False
    if self.max_outstanding is None:
        # no cap configured: forward work is never throttled
        return False
    if self.outstanding >= self.max_outstanding:
        return True
    return False
def _get_work_item_key(self) -> UniqueKey:
with self.work_list_condition_lock:
while len(self.work_list) == 0:
self.work_list_condition_lock.wait()
# execute backward first (if backward phase in work_list)
select_work_list_key = None
for key in self.work_list:
work_item = self.work_list[key]
if work_item.phase == Phase.BACKWARD:
return key
if self._skip_forward(work_item.phase):
if work_item.phase == Phase.FORWARD and \
self.max_outstanding is not None and \
self.outstanding >= self.max_outstanding:
continue
else:
select_work_list_key = key
if select_work_list_key is not None and \
select_work_list_key.phase == Phase.FORWARD and \
key.phase == Phase.BACKWARD:
continue
if select_work_list_key is None:
select_work_list_key = key
else:
phase_pair = (select_work_list_key.phase, key.phase)
# choose forward first
if phase_pair == (Phase.BACKWARD, Phase.FORWARD):
select_work_list_key = key
elif phase_pair == (Phase.FORWARD, Phase.BACKWARD):
continue
# choose work_item which has a smaller microbactch_id first
elif key.microbatch_id < select_work_list_key.microbatch_id:
select_work_list_key = key
return select_work_list_key
def _consume_work_item_by_phase(self, work_item: WorkItem):
@ -369,7 +377,10 @@ class Worker:
microbatch_id = work_item.microbatch_id
consume_result = None
# color_debug(f'rank_{self.rank} enter consume', 'consume', 'blue')
# if self.pp_rank == 0:
# print(f"I am rank_{self.pp_rank} microbatch_id : {microbatch_id}", work_item.phase, len(self.work_list))
# color_debug(f'rank_{self.pp_rank} enter consume', 'consume', 'blue')
if phase == Phase.FORWARD:
self.outstanding += 1
@ -381,19 +392,20 @@ class Worker:
args[i] = arg_obj.requires_grad_()
# TODO : use process manager to acquire rank info later
is_last_stage = len(self.consumer_stage_ids) == 0
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
# last stage doesn't need to do checkpoint, for it will do backward instantly
if self.checkpoint and not is_last_stage:
with torch.no_grad():
consume_result = self.cur_rank_module(*args, **kwargs)
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None
stage_inputs = args
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
stage_outputs,
checkpoint=True)
else:
# TODO : replace with *args, **kwargs and ensure the consume_result is a tuple
consume_result = self.cur_rank_module(*args, **kwargs)
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = consume_result
stage_inputs = args
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
@ -415,15 +427,13 @@ class Worker:
stage_inputs = backward_cache.stage_inputs
grad_tensors = args
# color_debug(f'rank_{self.rank} before backward', 'consume', 'yellow')
use_checkpoint = backward_cache.checkpoint
if self.checkpoint:
stage_outputs = [self.cur_rank_module(*stage_inputs)]
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_inputs)]
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# color_debug(f'rank_{self.rank} after backward', 'consume', 'yellow')
# collect grad of input tensor
consume_result = []
for input_node in stage_inputs:
@ -453,7 +463,7 @@ class Worker:
work_item = self.work_list.pop(work_item_key)
color_debug(
f'rank {self.rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}',
'work loop', 'green')
with self.output_list_condition_lock:
@ -464,11 +474,8 @@ class Worker:
consume_result = self._consume_work_item_by_phase(work_item)
color_debug(
f'rank_{self.rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}',
'work loop', 'green')
# if work_item.stage_id == 1 and work_item.phase == Phase.BACKWARD:
# from time import sleep
# sleep(5)
work_item.output.set_result(consume_result)
@ -479,11 +486,11 @@ class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
module_partitions,
chunk,
world_size,
stage_num,
num_microbatches,
device: str,
max_outstanding=None,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
super().__init__()
@ -492,55 +499,86 @@ class PipelineEngineBase(ABC, nn.Module):
self.num_microbatches = num_microbatches
self.device = device
self.max_outstanding = max_outstanding
self.world_size = world_size
self.stage_num = stage_num
self.checkpoint = checkpoint
self.use_interleave = use_interleave
self.stage_to_worker_rref: Dict[int, PyRRef] = dict()
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
self._check_argument()
self._create_pp_rank_to_rpc_worker_id()
self._init_worker()
def _check_argument(self):
self.virtual_stage_num = self.stage_num * self.chunk
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
assert self.virtual_stage_num == len(
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
if self.use_interleave:
assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
def _get_actual_stage_num(self):
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
def _create_pp_rank_to_rpc_worker_id(self):
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
of partitions will be moved to device 0 and the others to device 1
"""
stage_num = self.stage_num
actual_stage_num = self._get_actual_stage_num()
self.pp_rank_to_rpc_worker_id = [0] * actual_stage_num
for pp_rank in range(actual_stage_num):
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
def _init_worker(self):
world_size = self.world_size
actual_stage_num = self._get_actual_stage_num()
max_outstanding = self.max_outstanding
checkpoint = self.checkpoint
num_microbatches = self.num_microbatches
device = self.device
# TODO : world size is correct ?
for rank in range(world_size):
cur_rank_module = self.module_partitions[rank]
self.stage_to_worker_rref[rank] = rpc.remote(rank,
Worker,
args=(cur_rank_module, rank, world_size, num_microbatches,
max_outstanding, device, checkpoint))
for pp_rank in range(actual_stage_num):
module_partition = self.module_partitions[pp_rank]
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
Worker,
args=(module_partition, pp_rank, actual_stage_num,
num_microbatches, max_outstanding, device,
checkpoint))
# let each worker know global worker rref (include itself)
for rank in range(world_size):
self.stage_to_worker_rref[rank].rpc_sync().sync_global_worker_rrefs(self.stage_to_worker_rref)
for pp_rank in range(actual_stage_num):
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
@abstractmethod
def forward_backward(self):
pass
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
    """Fetch the parameters held by every worker.

    Calls ``get_parameters()`` synchronously on each worker in
    ``self.pp_rank_to_worker_rref`` and returns a dict keyed like that
    mapping, whose values are lists of the returned parameters.
    """
    return {
        stage_id: list(worker_rref.rpc_sync().get_parameters())
        for stage_id, worker_rref in self.pp_rank_to_worker_rref.items()
    }
def remote_grad(self) -> Dict[int, List[torch.Tensor]]:
    """Fetch the parameter gradients held by every worker.

    Calls ``get_parameter_gradients()`` synchronously on each worker in
    ``self.pp_rank_to_worker_rref`` and returns a dict keyed like that
    mapping, whose values are lists of the returned gradients.
    """
    return {
        stage_id: list(worker_rref.rpc_sync().get_parameter_gradients())
        for stage_id, worker_rref in self.pp_rank_to_worker_rref.items()
    }
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions,
chunk,
world_size,
num_microbatches,
device: str,
max_outstanding=None,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
use_interleave, checkpoint)
# TODO : adjust to args and kwargs
def forward_backward(self, batch: torch.Tensor):
first_stage_worker = self.stage_to_worker_rref[0]
first_stage_worker = self.pp_rank_to_worker_rref[0]
microbatch_size = len(batch) // self.num_microbatches
actual_stage_num = self._get_actual_stage_num()
microbatch_iter = range(self.num_microbatches)
if use_progress:
@ -550,31 +588,63 @@ class FillDrainPipelineEngine(PipelineEngineBase):
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
# forward subscribe asynchronously
for rank in range(1, self.world_size, 1):
worker_rref = self.stage_to_worker_rref[rank]
for pp_rank in range(1, actual_stage_num, 1):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_async().subscribe_producer(microbatch_id)
# backward subscribe asynchronously
for rank in range(self.world_size - 2, -1, -1):
worker_rref = self.stage_to_worker_rref[rank]
for pp_rank in range(actual_stage_num - 2, -1, -1):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_async().subscribe_consumer(microbatch_id)
# run one microbatch
first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch)
# wait forward
# TODO : all the node to output
forward_result = None
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
for microbatch_id in range(self.num_microbatches):
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret = last_worker_rref.rpc_sync().get_output_by_key(key)
if forward_result is None:
forward_result = [[]] * len(ret)
for i in range(len(forward_result)):
forward_result[i].append(ret[i])
class OneFOneBPipelineEngine(FillDrainPipelineEngine):
# wait for last backward in rank0
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
first_stage_worker.rpc_sync().get_output_by_key(key)
return forward_result
class FillDrainPipelineEngine(PipelineEngineBase):
    """Pipeline engine that never caps the number of outstanding forward passes.

    It forwards every argument to ``PipelineEngineBase`` unchanged and always
    passes ``max_outstanding=None``.
    """

    def __init__(self,
                 module_partitions: List[nn.Module],
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 use_interleave: bool = False,
                 checkpoint: bool = False) -> None:
        # the fifth positional argument of the base constructor is max_outstanding
        unlimited = None
        super().__init__(module_partitions, stage_num, num_microbatches, device, unlimited, chunk, use_interleave,
                         checkpoint)
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
max_outstanding=None,
chunk: int = 1,
use_interleave: bool = False,
checkpoint: bool = False) -> None:
if max_outstanding is None:
max_outstanding = world_size
super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding,
use_interleave, checkpoint)
max_outstanding = len(module_partitions)
super().__init__(module_partitions, stage_num, num_microbatches, device, max_outstanding, chunk, use_interleave,
checkpoint)

View File

@ -11,14 +11,14 @@ from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOn
class TestModel(nn.Module):
def __init__(self, rank, world_size, feat_num, h) -> None:
def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
super().__init__()
self.rank = rank
self.is_last_rank = rank == world_size - 1
self.linear_name = f'linear_{rank}'
if rank == 0:
self.rank = stage_id
self.is_last_rank = stage_id == actual_stage_num - 1
self.linear_name = f'linear_{stage_id}'
if stage_id == 0:
setattr(self, self.linear_name, nn.Linear(feat_num, h))
elif rank == world_size - 1:
elif stage_id == actual_stage_num - 1:
setattr(self, self.linear_name, nn.Linear(h, 1))
else:
setattr(self, self.linear_name, nn.Linear(h, h))
@ -35,32 +35,35 @@ class TestModel(nn.Module):
def run_main(args):
torch.manual_seed(100)
sample_num = 128
feat_num = 10000
h = 10000
device = args.device
world_size = args.world_size
batch_size = 128
stage_num = args.world_size
chunk = args.chunk
num_microbatches = args.num_microbatches
actual_stage_num = stage_num * chunk
use_interleave = args.use_interleave
use_checkpoint = args.use_checkpoint
sample_num = 1024
feat_num = 10
h = 10
batch_size = 1024
assert sample_num % batch_size == 0
batch_num = sample_num // batch_size
num_microbatches = world_size
input_sample = torch.randn((sample_num, feat_num), device=device)
module_partitions = [TestModel(rank, world_size, feat_num, h) for rank in range(world_size)]
module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
chunk=1,
world_size=world_size,
stage_num=stage_num,
num_microbatches=num_microbatches,
device=args.device,
max_outstanding=world_size,
use_interleave=False,
checkpoint=False)
device=device,
chunk=chunk,
use_interleave=use_interleave,
checkpoint=use_checkpoint)
for i in range(batch_num):
batch = input_sample[i * batch_size:(i + 1) * batch_size]
engine.forward_backward(batch)
_ = engine.forward_backward(input_sample)
def run_worker(rank, args):
@ -88,7 +91,11 @@ def run_worker(rank, args):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--chunk', type=int, default=1)
parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--use_interleave', action='store_true')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=str, default=128)

View File

@ -0,0 +1,150 @@
import os
import argparse
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from torch import autograd
from colorama import Back, Style
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
def color_debug(text, prefix=' ', color='blue'):
    """Print ``text`` after ``prefix``, with ``prefix`` on a coloured background.

    ``color`` is upper-cased and looked up as an attribute of colorama's
    ``Back``; the style is reset before ``text`` is printed.
    """
    background = getattr(Back, color.upper())
    print(background, prefix, Style.RESET_ALL, text)
class TestModel(nn.Module):
    """One stage of a toy pipeline model: a single ``nn.Linear`` layer.

    The layer is registered under the attribute name ``linear_<stage_id>``.
    Stage 0 maps ``feat_num -> h``, the last stage maps ``h -> 1`` and every
    other stage maps ``h -> h``. The last stage also reduces its output to a
    scalar with ``sum()``.
    """

    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
        super().__init__()
        final_stage_id = actual_stage_num - 1

        self.rank = stage_id
        self.is_last_rank = stage_id == final_stage_id
        self.linear_name = f'linear_{stage_id}'

        # stage 0 is tested first, so a single-stage model gets feat_num -> h
        if stage_id == 0:
            in_dim, out_dim = feat_num, h
        elif stage_id == final_stage_id:
            in_dim, out_dim = h, 1
        else:
            in_dim, out_dim = h, h
        setattr(self, self.linear_name, nn.Linear(in_dim, out_dim))

    def forward(self, x) -> torch.Tensor:
        layer: nn.Module = getattr(self, self.linear_name)
        result: torch.Tensor = layer(x)
        return result.sum() if self.is_last_rank else result
def run_main(args):
    """Run one pipelined forward/backward and compare it with a single-process run.

    Builds ``world_size * chunk`` ``TestModel`` stages, drives them through
    ``OneFOneBPipelineEngine`` on a random batch, and collects the sum of the
    forward outputs plus the sum of every parameter gradient. The same stages
    are then chained in an ``nn.Sequential`` and run locally. The summed
    absolute difference of the two result vectors, divided by the number of
    stages, must stay below 0.01.

    Raises:
        AssertionError: if the pipeline and single-process results diverge.
    """
    torch.manual_seed(100)

    device = args.device
    stage_num = args.world_size
    chunk = args.chunk

    sample_num = 1024
    feat_num = 100
    h = 100
    batch_size = 1024
    assert sample_num % batch_size == 0

    # one microbatch per stage; args.num_microbatches is not consulted here
    num_microbatches = stage_num * 1

    input_sample = torch.randn((sample_num, feat_num), device=device)

    total_partitions = stage_num * chunk
    module_partitions = [TestModel(pp_rank, total_partitions, feat_num, h) for pp_rank in range(total_partitions)]

    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device=device,
                                    chunk=chunk,
                                    use_interleave=args.use_interleave,
                                    checkpoint=args.use_checkpoint)

    forward_result = engine.forward_backward(input_sample)
    actual_stage_num = engine._get_actual_stage_num()

    # pipeline side: forward sum first, then per-parameter gradient sums in stage order
    pipeline_stats = [sum(forward_result[0]).item()]
    grad = engine.remote_grad()
    for stage_id in range(actual_stage_num):
        pipeline_stats.extend(p.sum().item() for p in grad[stage_id])

    # reference side: the same stages chained locally, collected in the same layout
    test_model = nn.Sequential(*module_partitions).to(device)
    input_sample = input_sample.requires_grad_()
    out_val = test_model(input_sample).sum()
    autograd.backward(out_val)

    reference_stats = [out_val.item()]
    reference_stats.extend(p.grad.sum().item() for p in test_model.parameters())

    difference = torch.tensor(pipeline_stats) - torch.tensor(reference_stats)
    distance = difference.abs().sum().item()
    kappa = round(distance / actual_stage_num, 5)
    assert kappa < 0.01, f"kappa({kappa}) is too big, PP result may be incorrect!"
def run_worker(rank, args):
    """Entry point of one spawned process: join the RPC group and, on rank 0, run the test.

    Args:
        rank: index of this process; it is registered as RPC worker ``work<rank>``.
        args: parsed command-line namespace (see ``parse_args``).
    """
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    world_size = args.world_size

    # config rpc
    # CUDA tensors are not sent over torch rpc by default, so a device map
    # towards every peer has to be configured before init_rpc
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
    for peer_rank in range(world_size):
        peer_name = f'work{peer_rank}'
        options.set_device_map(peer_name, {rank: peer_rank})

    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)

    # in rpc mode only rank 0 drives the computation
    if rank == 0:
        run_main(args)

    # acts as the barrier for all processes (original comment: "barrier here")
    rpc.shutdown()
def parse_args():
    """Parse the command-line options of the pipeline RPC test script.

    Returns:
        argparse.Namespace with ``world_size``, ``num_microbatches``,
        ``chunk``, ``use_checkpoint``, ``use_interleave``, ``device``,
        ``master_addr``, ``master_port`` and ``num_worker_threads``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=2)
    parser.add_argument('--num_microbatches', type=int, default=2)
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.add_argument('--use_interleave', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--master_addr', type=str, default='localhost')
    # kept as a string because it is written straight into os.environ
    parser.add_argument('--master_port', type=str, default='29020')
    # must be an int: it is handed to TensorPipeRpcBackendOptions as a thread
    # count, and the former type=str turned any explicit value into a string
    parser.add_argument('--num_worker_threads', type=int, default=128)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry: validate the options, then start one process per rank.
    args = parse_args()
    world_size = args.world_size
    enough_microbatches = args.num_microbatches >= world_size
    assert enough_microbatches, "num_microbatches cannot be fewer than world_size!"
    assert args.device in ('cpu', 'cuda'), "device must be cpu or cuda!"
    # mp.spawn passes the process index as the first argument of run_worker
    mp.spawn(run_worker, args=(args,), nprocs=world_size)