[pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera (#1595)

* [pipeline/tuning] improve dispatch performance in both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera
pull/1609/head
Kirigaya Kazuto 2022-09-19 11:44:18 +08:00 committed by GitHub
parent c9e8ce67b8
commit edc9e419ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 614 additions and 163 deletions

View File

@ -1,8 +1,9 @@
import threading
from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
from abc import ABC
from abc import ABC, abstractmethod
import sys
import os
import torch
from torch import nn
@ -17,12 +18,10 @@ from time import time
from colorama import Back, Style
# config for debug and test
use_color_debug = False
use_progress = False
use_color_debug = True
# TODO:
# 1. replace world_size with other parameters
# 2. adjust to args and kwargs
# 1. adjust to args and kwargs (pytree)
def color_debug(text, prefix=' ', color='blue'):
@ -137,24 +136,24 @@ class BackwardCache:
setattr(self, arg_name, locals()[arg_name])
class Worker:
class WorkerBase(ABC):
def __init__(self,
module_partition: nn.Module,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
use_1F1B: bool,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
super().__init__()
self.pp_rank = pp_rank
self.actual_stage_num = actual_stage_num
self.num_microbatches = num_microbatches
self.checkpoint = checkpoint
self.device = device
self.use_1F1B = use_1F1B
self._initialize_outstanding_range()
# variable and const for context managment
@ -172,23 +171,14 @@ class Worker:
# module partitions
self.module_partition = module_partition.to(device)
if criterion:
assert callable(criterion)
self.criterion = criterion
self.metric = metric
# container to maintain loop
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
self.microbatch_id_to_labels: Dict[int, Any] = dict()
self.work_list: Dict[UniqueKey, WorkItem] = dict()
self.output_list: Dict[UniqueKey, WorkItem] = dict()
# context to maintain loop
self._initialize_context_container()
# lock for the list
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
self.step_lock = threading.Lock()
self.step_lock.acquire()
self._initialize_lock()
# main loop
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
@ -199,13 +189,23 @@ class Worker:
def _initialize_outstanding_range(self):
outstanding_range = None
if self.use_1F1B:
if self.pp_rank == self.actual_stage_num - 1:
outstanding_range = (0, 1)
else:
outstanding_range = (self.actual_stage_num, self.actual_stage_num)
if self.pp_rank == self.actual_stage_num - 1:
outstanding_range = (0, 1)
else:
outstanding_range = (self.actual_stage_num, self.actual_stage_num)
self.outstanding_range = outstanding_range
def _initialize_context_container(self):
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
self.microbatch_id_to_labels: Dict[int, Any] = dict()
self.work_list: Dict[UniqueKey, WorkItem] = dict()
self.output_list: Dict[UniqueKey, WorkItem] = dict()
def _initialize_lock(self):
self.work_list_condition_lock = threading.Condition(threading.Lock())
self.output_list_condition_lock = threading.Condition(threading.Lock())
self.label_lock = threading.Condition(threading.Lock())
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
@ -241,12 +241,15 @@ class Worker:
forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
'magenta')
self.work_list_condition_lock.notify_all()
# just for last pp_rank
def set_labels(self, microbatch_id: int, microlabels: Any):
self.microbatch_id_to_labels[microbatch_id] = microlabels
with self.label_lock:
self.microbatch_id_to_labels[microbatch_id] = microlabels
self.label_lock.notify_all()
# just for last pp_rank
def _begin_backward(self, microbatch_id: int):
@ -354,51 +357,17 @@ class Worker:
if next_rank <= self.actual_stage_num - 1:
self.consumer_stage_ids.append(next_rank)
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
pp_rank = self.pp_rank
actual_stage_num = self.actual_stage_num
num_microbatches = self.num_microbatches
is_last_stage = pp_rank == actual_stage_num - 1
"""
this method control the order of the microbatch to consume
"""
if self.outstanding_range:
if self.outstanding <= self.outstanding_range[0]:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
elif self.outstanding >= self.outstanding_range[1]:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else:
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
def is_first_stage(self):
    """Return True when this worker holds pp_rank 0."""
    first_rank = 0
    return self.pp_rank == first_rank
target_key = UniqueKey(target_microbatch_id, target_phase)
# change outstanding_range at:
# 1. forward times reach actual_stage_num, this is the end of continuous forward
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
if not is_last_stage and \
target_key.phase == Phase.FORWARD:
if target_key.microbatch_id == actual_stage_num - 1:
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
elif target_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
else:
if self.forward_times < num_microbatches:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
target_key = UniqueKey(target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
def is_last_stage(self):
    """Return True when this worker holds the highest pp_rank (actual_stage_num - 1)."""
    last_rank = self.actual_stage_num - 1
    return self.pp_rank == last_rank
def _consume_work_item_by_phase(self, work_item: WorkItem):
phase = work_item.phase
@ -408,9 +377,8 @@ class Worker:
forward_only = work_item.forward_only
consume_result = None
# TODO : use process manager to acquire rank info later
is_first_stage = (self.pp_rank == 0)
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
is_first_stage = self.is_first_stage()
is_last_stage = self.is_last_stage()
# if self.pp_rank == 0:
# print(
@ -433,6 +401,21 @@ class Worker:
if forward_only:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
# TODO : integrate output list
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
labels = self.microbatch_id_to_labels.pop(microbatch_id)
loss: torch.Tensor = self.criterion(consume_result, labels)
if self.metric is not None:
metric_result = self.metric(consume_result, labels)
if isinstance(metric_result, torch.Tensor):
metric_result = metric_result.item()
else:
metric_result = None
consume_result = [loss.item(), metric_result]
stage_outputs = None
stage_inputs = None
use_checkpoint = None
@ -444,10 +427,21 @@ class Worker:
use_checkpoint = True
else:
consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
labels = self.microbatch_id_to_labels.pop(microbatch_id)
loss: torch.Tensor = self.criterion(consume_result, labels)
consume_result = loss.item()
if self.metric is not None:
metric_result = self.metric(consume_result, labels)
if isinstance(metric_result, torch.Tensor):
metric_result = metric_result.item()
else:
metric_result = None
consume_result = [loss.item(), metric_result]
else:
loss = consume_result
@ -486,6 +480,7 @@ class Worker:
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_inputs)]
# overlap recompute and future.wait
grad_tensors = get_real_args(args)
@ -513,11 +508,17 @@ class Worker:
grad_sum += p.grad.sum()
return grad_sum
def _is_first_step(self, work_item) -> bool:
def _is_first_step(self, work_item: WorkItem) -> bool:
return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
def _is_last_step(self, work_item) -> bool:
return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1
def _is_last_step(self, work_item: WorkItem) -> bool:
    """Tell whether ``work_item`` is the final item of the current batch.

    The final item is the one of the last microbatch in the last phase that
    will run: the forward phase for forward-only items, the backward phase
    otherwise.
    """
    closing_phase = Phase.FORWARD if work_item.forward_only else Phase.BACKWARD
    in_closing_phase = work_item.phase == closing_phase
    on_final_microbatch = work_item.microbatch_id == self.num_microbatches - 1
    return in_closing_phase and on_final_microbatch
# do the main loop to consume ready_list
def _work_loop(self):
@ -551,7 +552,7 @@ class Worker:
# if is last step in one batch reset context and do step
if self._is_last_step(work_item):
if hasattr(self, 'optimizer'):
if hasattr(self, 'optimizer') and not work_item.forward_only:
self.step()
self.forward_times = 0
self.backward_times = 0
@ -560,22 +561,22 @@ class Worker:
def initialize_optimizer(self, optimizer_class: type, **kwargs):
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
self.step_lock = threading.Lock()
self.step_lock.acquire()
def wait_for_step(self):
self.step_lock.acquire()
def step(self):
# print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
self.optimizer.step()
# print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
self.optimizer.zero_grad()
self.step_lock.release()
class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
worker_type,
module_partitions,
stage_num,
num_microbatches,
@ -583,17 +584,19 @@ class PipelineEngineBase(ABC, nn.Module):
use_1F1B=False,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
super().__init__()
self.worker_type = worker_type
self.module_partitions: List[nn.Module] = module_partitions
self.chunk = chunk
self.criterion = criterion
self.metric = metric
self.num_microbatches = num_microbatches
self.device = device
self.use_1F1B = use_1F1B
self.stage_num = stage_num
self.checkpoint = checkpoint
self.use_interleave = chunk > 1
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@ -601,26 +604,24 @@ class PipelineEngineBase(ABC, nn.Module):
self._check_argument()
self._create_pp_rank_to_rpc_worker_id()
self._create_pp_rank_to_module_partition_id()
self._init_worker()
def _check_argument(self):
def _check_argument(self) -> None:
self.virtual_stage_num = self.stage_num * self.chunk
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
assert self.virtual_stage_num == len(
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
if self.use_interleave:
assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
def _get_actual_stage_num(self):
def _get_actual_stage_num(self) -> int:
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
def _create_pp_rank_to_rpc_worker_id(self):
def _create_pp_rank_to_rpc_worker_id(self) -> None:
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
e.g. If a model is split into 4 parts, which means len(self.module_partitions) == 4.
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
of partitions will be moved to device 0 and the others to device 1
"""
stage_num = self.stage_num
actual_stage_num = self._get_actual_stage_num()
@ -628,28 +629,39 @@ class PipelineEngineBase(ABC, nn.Module):
for pp_rank in range(actual_stage_num):
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
def _init_worker(self):
def _create_pp_rank_to_module_partition_id(self) -> None:
"""By default(both fill drain and 1F1B), length of model partitions equal to
actual_stage_num, so allocate model partition to corresponding stage
"""
actual_stage_num = self._get_actual_stage_num()
self.pp_rank_to_module_partition_id = [0] * actual_stage_num
for pp_rank in range(actual_stage_num):
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
def _init_worker(self) -> None:
actual_stage_num = self._get_actual_stage_num()
use_1F1B = self.use_1F1B
worker_type = self.worker_type
checkpoint = self.checkpoint
num_microbatches = self.num_microbatches
device = self.device
criterion = self.criterion
metric = self.metric
for pp_rank in range(actual_stage_num):
module_partition = self.module_partitions[pp_rank]
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
module_partition = self.module_partitions[module_partition_id]
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
Worker,
worker_type,
args=(module_partition, pp_rank, actual_stage_num,
num_microbatches, use_1F1B, device, criterion,
num_microbatches, device, criterion, metric,
checkpoint))
# let each worker know global worker rref (include itself)
for pp_rank in range(actual_stage_num):
for pp_rank in self.pp_rank_to_worker_rref:
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
@ -670,65 +682,110 @@ class PipelineEngineBase(ABC, nn.Module):
grads[stage_id].append(grad)
return grads
def get_input_pp_ranks(self) -> List[int]:
    """Return the pp_ranks that are fed with input microbatches (only rank 0 here)."""
    first_rank = 0
    return [first_rank]
def get_output_pp_ranks(self) -> List[int]:
    """Return the pp_ranks whose forward output is collected (only the last stage here)."""
    last_rank = self._get_actual_stage_num() - 1
    return [last_rank]
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
                        output_pp_ranks: List[int], ret_future):
    """Throttle input dispatch so it stays at most ``actual_stage_num`` microbatches ahead.

    Nothing happens for the first ``actual_stage_num`` microbatches. After
    that, the microbatch dispatched ``actual_stage_num`` steps earlier must
    have finished: its forward future is awaited on every output rank when
    running forward-only or without 1F1B, otherwise its backward output is
    fetched synchronously from every input rank.
    """
    actual_stage_num = self._get_actual_stage_num()
    use_1F1B = self.use_1F1B

    if microbatch_id < actual_stage_num:
        return

    awaited_microbatch_id = microbatch_id - actual_stage_num
    if forward_only or not use_1F1B:
        for pp_rank in output_pp_ranks:
            ret_future[pp_rank][awaited_microbatch_id].wait()
        return

    key = UniqueKey(awaited_microbatch_id, Phase.BACKWARD)
    for pp_rank in input_pp_ranks:
        worker_rref = self.pp_rank_to_worker_rref[pp_rank]
        worker_rref.rpc_sync().get_output_by_key(key)
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
    """Send ``microbatch`` to every input rank through a non-blocking remote call."""
    rank_to_rref = self.pp_rank_to_worker_rref
    for pp_rank in input_pp_ranks:
        # TODO : add relationship between input_pp_ranks and parts of microbatch
        rank_to_rref[pp_rank].remote().set_input(microbatch_id, microbatch, forward_only)
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
    """Send ``microlabels`` to every output rank through a non-blocking remote call."""
    rank_to_rref = self.pp_rank_to_worker_rref
    for pp_rank in output_pp_ranks:
        # TODO : add relationship between output_pp_ranks and parts of microlabels
        rank_to_rref[pp_rank].remote().set_labels(microbatch_id, microlabels)
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
    """Request the forward output of ``microbatch_id`` from every output rank.

    The asynchronous request handles are stored in ``ret_future`` under
    ``[pp_rank][microbatch_id]`` so they can be awaited later.
    """
    forward_key = UniqueKey(microbatch_id, Phase.FORWARD)
    for pp_rank in output_pp_ranks:
        pending = self.pp_rank_to_worker_rref[pp_rank].rpc_async().get_output_by_key(forward_key)
        ret_future[pp_rank][microbatch_id] = pending
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
    """Block until every input rank has produced the backward output of the last microbatch.

    Does nothing in forward-only mode.
    """
    if forward_only:
        return

    last_backward_key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
    for pp_rank in input_pp_ranks:
        self.pp_rank_to_worker_rref[pp_rank].rpc_sync().get_output_by_key(last_backward_key)
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
forward_result = []
for pp_rank in output_pp_ranks:
worker_forward_result = [None] * self.num_microbatches
for microbatch_id in range(self.num_microbatches):
ret = ret_future[pp_rank][microbatch_id].wait()
worker_forward_result[microbatch_id] = ret
worker_forward_result = list(zip(*worker_forward_result))
forward_result.extend(worker_forward_result)
return forward_result
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
if labels is not None:
assert len(batch) == len(labels)
if not forward_only:
assert hasattr(self, 'optimizer_class')
num_microbatches = self.num_microbatches
microbatch_size = len(batch) // num_microbatches
actual_stage_num = self._get_actual_stage_num()
first_worker_rref = self.pp_rank_to_worker_rref[0]
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
# If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
input_pp_ranks = self.get_input_pp_ranks()
output_pp_ranks = self.get_output_pp_ranks()
microbatch_iter = range(num_microbatches)
if use_progress:
microbatch_iter = tqdm(microbatch_iter)
# a cache to collect data and control flow
ret_future = self._create_ret_future(output_pp_ranks)
ret_future: List[Future] = [None] * num_microbatches
for microbatch_id in microbatch_iter:
# control data input speed
for microbatch_id in range(num_microbatches):
# control data input speed
# to prevent exceed of wait limitations
if microbatch_id >= actual_stage_num:
if forward_only or not self.use_1F1B:
ret_future[microbatch_id - actual_stage_num].wait()
else:
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
first_worker_rref.rpc_sync().get_output_by_key(key)
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
# set input
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
microbatch = microbatch.cuda()
first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)
# set labels
if not forward_only and labels is not None:
if labels is not None:
microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
microlabels = microlabels.cuda()
last_worker_rref.remote().set_labels(microbatch_id, microlabels)
self._set_labels(output_pp_ranks, microbatch_id, microlabels)
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
# get data asynchronously
self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future)
# wait for last backward in rank0
if not forward_only:
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
first_worker_rref.rpc_sync().get_output_by_key(key)
# wait for first rank to ensure all backwards are done
self._ensure_backward(forward_only, input_pp_ranks)
# collect forward result
# TODO : all the node to output
forward_result = None
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
for microbatch_id in range(self.num_microbatches):
key = UniqueKey(microbatch_id, Phase.FORWARD)
ret = ret_future[microbatch_id].wait()
if forward_result is None:
forward_result = [[]] * len(ret)
for i in range(len(forward_result)):
forward_result[i].append(ret[i])
if hasattr(self, 'optimizer_class'):
if not forward_only and labels is not None:
# wait for all step
# TODO : more elegant ?
for pp_rank in self.pp_rank_to_worker_rref:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().wait_for_step()
@ -751,31 +808,3 @@ class PipelineEngineBase(ABC, nn.Module):
for fut in self.step_futs:
fut.wait()
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
checkpoint: bool = False) -> None:
use_1F1B = False
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
checkpoint: bool = False) -> None:
use_1F1B = True
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)

View File

@ -0,0 +1,277 @@
from typing import List, Callable, Dict
import torch.nn as nn
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the class for use
class FillDrainWorker(WorkerBase):
    """Worker for the fill-drain schedule: all forward passes first, then all backward passes."""

    def _get_work_item_key(self) -> UniqueKey:
        """Pick the next work item of the schedule and wait until it is available.

        Forward items are requested in order while ``forward_times`` is below
        ``num_microbatches``; afterwards backward items are requested in order
        of ``backward_times``.

        Returns:
            UniqueKey: key of the work item to consume next; it is guaranteed
            to be present in ``work_list`` at the moment the wait finishes.
        """
        if self.forward_times < self.num_microbatches:
            target_key = UniqueKey(self.forward_times, Phase.FORWARD)
        else:
            target_key = UniqueKey(self.backward_times, Phase.BACKWARD)

        # block until the selected item has been put into the work list
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)

        return target_key
class FillDrainPipelineEngine(PipelineEngineBase):
    """Pipeline engine whose stages are driven by ``FillDrainWorker``."""

    def __init__(self,
                 module_partitions: List[nn.Module],
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False) -> None:
        """Validate the interleaving constraint and build the engine without 1F1B.

        Raises:
            AssertionError: if ``chunk > 1`` and ``num_microbatches`` is not a
                multiple of ``stage_num``.
        """
        interleaved = chunk > 1
        if interleaved:
            assert num_microbatches % stage_num == 0, \
                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"

        # the fill-drain schedule never uses 1F1B
        super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, False, chunk,
                         criterion, metric, checkpoint)
class OneFOneBWorker(WorkerBase):
    """Worker for the 1F1B (one forward, one backward) schedule."""

    def _get_work_item_key(self) -> UniqueKey:
        """Pick the next work item of the 1F1B schedule and wait until it is available.

        The phase is chosen by comparing ``outstanding`` against
        ``outstanding_range``: a forward item at or below the lower bound, a
        backward item at or above the upper bound.

        Returns:
            UniqueKey: key of the work item to consume next.

        Raises:
            ValueError: if ``outstanding`` lies strictly between the two
                bounds of ``outstanding_range``.
        """
        rank = self.pp_rank
        stage_count = self.actual_stage_num

        if self.outstanding <= self.outstanding_range[0]:
            target_key = UniqueKey(self.forward_times, Phase.FORWARD)
        elif self.outstanding >= self.outstanding_range[1]:
            target_key = UniqueKey(self.backward_times, Phase.BACKWARD)
        else:
            raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

        # outstanding_range changes on a forward item of every stage but the last one:
        # 1. when the forward of microbatch (actual_stage_num - 1) is selected,
        #    which ends the run of continuous forwards
        # 2. when the forward of the last microbatch is selected, which ends 1F1B mode
        if target_key.phase == Phase.FORWARD and rank != stage_count - 1:
            if target_key.microbatch_id == stage_count - 1:
                remaining_stages = stage_count - rank
                self.outstanding_range = (remaining_stages - 1, remaining_stages)
            elif target_key.microbatch_id == self.num_microbatches - 1:
                self.outstanding_range = (0, 0)

        # block until the selected item has been put into the work list
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)

        return target_key
class OneFOneBPipelineEngine(PipelineEngineBase):
    """Pipeline engine whose stages are driven by ``OneFOneBWorker``."""

    def __init__(self,
                 module_partitions: List[nn.Module],
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False) -> None:
        """Validate the interleaving constraint and build the engine with 1F1B enabled.

        Raises:
            AssertionError: if ``chunk > 1`` and ``num_microbatches`` is not a
                multiple of ``stage_num``.
        """
        interleaved = chunk > 1
        if interleaved:
            assert num_microbatches % stage_num == 0, \
                "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"

        # this engine always runs in 1F1B mode
        super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, True, chunk,
                         criterion, metric, checkpoint)
class ChimeraWorker(WorkerBase):
    """Worker for the Chimera schedule.

    pp_ranks are grouped into consecutive blocks of ``actual_stage_num``
    ranks; ranks below ``actual_stage_num`` consume even microbatch ids and
    the remaining ranks consume odd microbatch ids.
    """

    def _get_producer_consumer(self) -> None:
        """Subscribe to the neighbouring ranks inside this worker's own pipeline.

        The producer is the previous rank and the consumer is the next rank,
        each only if it still lies inside the block of ``actual_stage_num``
        ranks that contains this worker.

        Raises:
            AssertionError: if producers or consumers were already set.
        """
        rank = self.pp_rank
        block_size = self.actual_stage_num
        # first and last pp_rank of the pipeline this worker belongs to
        lowest_rank = rank - rank % block_size
        highest_rank = lowest_rank + block_size - 1

        assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"

        # must be arranged in the order of the inputs of the current forward
        # Just for demo
        self.producer_stage_ids = [rank - 1] if rank - 1 >= lowest_rank else []
        self.consumer_stage_ids = [rank + 1] if rank + 1 <= highest_rank else []

    def _get_work_item_key(self) -> UniqueKey:
        """Pick the next work item of the Chimera schedule and wait until it is available.

        Each pipeline handles ``num_microbatches // 2`` microbatches. While
        forwards remain, a worker runs a forward item, except that the last
        rank of a pipeline switches to a backward item whenever its count of
        completed forward blocks exceeds ``backward_times``. Once all
        forwards are done only backward items are selected.

        Returns:
            UniqueKey: key of the work item to consume next.
        """
        rank = self.pp_rank
        stage_count = self.actual_stage_num
        forwards_done = self.forward_times
        backwards_done = self.backward_times

        if forwards_done >= self.num_microbatches // 2:
            run_backward = True
        elif (rank + 1) % stage_count != 0:
            # every rank but the last one of a pipeline keeps forwarding
            run_backward = False
        else:
            forward_blocks = forwards_done // (self.num_microbatches // stage_count)
            run_backward = forward_blocks > backwards_done

        if run_backward:
            phase, local_microbatch_id = Phase.BACKWARD, backwards_done
        else:
            phase, local_microbatch_id = Phase.FORWARD, forwards_done

        # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
        # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
        parity = 1 if rank >= stage_count else 0
        target_key = UniqueKey(local_microbatch_id * 2 + parity, phase)

        # block until the selected item has been put into the work list
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)

        return target_key

    def is_first_stage(self):
        """Return True when this worker is the first rank of its pipeline."""
        position = self.pp_rank % self.actual_stage_num
        return position == 0

    def is_last_stage(self):
        """Return True when this worker is the last rank of its pipeline."""
        position = self.pp_rank % self.actual_stage_num
        return position == self.actual_stage_num - 1
class ChimeraPipelineEngine(PipelineEngineBase):
    """Pipeline engine for the Chimera schedule.

    Two pipelines are laid over the same ``stage_num`` RPC workers: pp_ranks
    ``[0, stage_num)`` form the "up" pipeline and pp_ranks
    ``[stage_num, 2 * stage_num)`` form the "down" pipeline, which is mapped
    onto the workers in reverse order. Both pipelines use the same module
    partition ids. Microbatches with an even id are routed to the up
    pipeline and those with an odd id to the down pipeline.
    """

    def __init__(self,
                 module_partitions,
                 stage_num,
                 num_microbatches,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False) -> None:
        """Validate the microbatch count and build the engine with ``ChimeraWorker``.

        Raises:
            AssertionError: if ``num_microbatches`` is not a multiple of
                ``stage_num`` or is not even.
        """
        assert num_microbatches % stage_num == 0, \
            "In Chimera, num_microbatches must be a multiple of stage_num!"
        # Microbatches alternate between the two pipelines and every worker
        # consumes num_microbatches // 2 of them, so an odd count would leave
        # one pipeline waiting for work that is never dispatched.
        assert num_microbatches % 2 == 0, \
            "In Chimera, num_microbatches must be even because microbatches alternate between up and down pipeline!"
        use_1F1B = False
        chunk = 1
        super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
                         criterion, metric, checkpoint)

    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
                            output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        """Apply no throttling: Chimera dispatches every microbatch immediately.

        The parameters mirror the hook of ``PipelineEngineBase`` so that
        positional and keyword calls behave the same for every engine.
        """
        pass

    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        """Map both pipelines onto the RPC workers.

        The up pipeline uses worker ``pp_rank``; the down pipeline is
        reversed, so pp_rank ``stage_num + i`` uses worker
        ``stage_num - i - 1``.
        """
        stage_num = self.stage_num
        up_worker_ids = list(range(stage_num))
        down_worker_ids = [stage_num - pp_rank - 1 for pp_rank in range(stage_num)]
        self.pp_rank_to_rpc_worker_id = up_worker_ids + down_worker_ids

    def _create_pp_rank_to_module_partition_id(self) -> None:
        """Give pp_rank ``i`` and pp_rank ``stage_num + i`` the same module partition ``i``."""
        partition_ids = list(range(self.stage_num))
        self.pp_rank_to_module_partition_id = partition_ids + partition_ids

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        """Allocate one future slot per microbatch for each output rank of both pipelines."""
        num_microbatches = self.num_microbatches
        stage_num = self.stage_num
        up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
        down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
        # merge up and down
        return {**up_ret_future, **down_ret_future}

    def _pipeline_offset(self, microbatch_id: int) -> int:
        """Return the pp_rank offset of the pipeline that handles ``microbatch_id``.

        The offset is 0 for the up pipeline (even ids) and ``stage_num`` for
        the down pipeline (odd ids).
        """
        return (microbatch_id % 2) * self.stage_num

    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
        """Send ``microbatch`` to the input ranks of the pipeline selected by its id."""
        offset = self._pipeline_offset(microbatch_id)
        for pp_rank in input_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)

    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
        """Send ``microlabels`` to the output ranks of the pipeline selected by the microbatch id."""
        offset = self._pipeline_offset(microbatch_id)
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_labels(microbatch_id, microlabels)

    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        """Request the forward output of ``microbatch_id`` from the pipeline that handles it."""
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        offset = self._pipeline_offset(microbatch_id)
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)

    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
        """Block until both pipelines have produced the backward output of their last microbatch.

        Does nothing in forward-only mode. The last microbatch of the up
        pipeline is ``num_microbatches - 2`` and the last one of the down
        pipeline is ``num_microbatches - 1``.
        """
        if forward_only:
            return

        stage_num = self.stage_num
        num_microbatches = self.num_microbatches
        up_key = UniqueKey(num_microbatches - 2, Phase.BACKWARD)
        down_key = UniqueKey(num_microbatches - 1, Phase.BACKWARD)

        for pp_rank in input_pp_ranks:
            up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
            down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]
            up_worker_rref.rpc_sync().get_output_by_key(up_key)
            down_worker_rref.rpc_sync().get_output_by_key(down_key)

    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        """Logic of collection of forward in Chimera.

        Currently, only one input one output model is supported. For every
        output rank the result of each microbatch is awaited on the pipeline
        that handled it, then the per-microbatch results are transposed with
        ``zip`` and appended to the returned list.
        """
        forward_result = []
        for pp_rank in output_pp_ranks:
            worker_forward_result = [
                ret_future[pp_rank + self._pipeline_offset(microbatch_id)][microbatch_id].wait()
                for microbatch_id in range(self.num_microbatches)
            ]
            forward_result.extend(zip(*worker_forward_result))

        return forward_result

BIN
data/cifar-10-python.tar.gz Normal file

Binary file not shown.

View File

@ -0,0 +1,43 @@
import torch
from torch import nn
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
def run_master(args):
    """Build a 4-stage Chimera engine over ``RpcTestModel`` partitions and run it.

    Reads ``args.epoch`` and ``args.device``; every other setting is fixed
    below. One random batch is pushed through ``forward_backward`` once per
    epoch.
    """
    torch.manual_seed(100)

    epoch = args.epoch
    device = args.device

    # pipeline layout
    stage_num = 4
    chunk = 1    # unused here: ChimeraPipelineEngine sets chunk to 1 itself
    num_microbatches = 4
    actual_stage_num = 4
    use_checkpoint = False

    # data layout
    sample_num = 1024
    feat_num = 10
    h = 10
    batch_size = 1024

    assert sample_num % batch_size == 0

    module_partitions = []
    for pp_rank in range(actual_stage_num):
        module_partitions.append(RpcTestModel(pp_rank, actual_stage_num, feat_num, h))

    engine = ChimeraPipelineEngine(
        module_partitions=module_partitions,
        stage_num=stage_num,
        num_microbatches=num_microbatches,
        device=device,
        checkpoint=use_checkpoint,
    )

    input_sample = torch.randn((sample_num, feat_num), device=device)

    for _ in range(epoch):
        _ = engine.forward_backward(input_sample, forward_only=False)
if __name__ == "__main__":
    args = parse_args()
    # this test always runs with 4 workers and 4 microbatches,
    # whatever was given on the command line
    args.world_size = 4
    args.num_microbatches = 4
    rpc_run(args, run_master)

View File

@ -3,7 +3,7 @@ from torch import nn
from torch import autograd
from torch.optim import SGD, Adam, RMSprop, Optimizer
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel

View File

@ -0,0 +1,102 @@
import os
from typing import Callable, List, Optional, Type, Union
import time
import pytest
import torch
import torch.nn as nn
from titans.dataloader.cifar10 import build_cifar
from torchvision.models import resnet50
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
from tqdm import tqdm
from rpc_test_utils import rpc_run, parse_args
import colossalai
import colossalai.nn as col_nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
def flatten(x):
    """Collapse every dimension of ``x`` from dim 1 onwards into one."""
    return torch.flatten(x, start_dim=1)
class Flatten(nn.Module):
    """``nn.Module`` that flattens its input from dim 1 onwards."""

    def forward(self, x):
        """Return ``x`` with all dimensions after the first merged."""
        return torch.flatten(x, 1)
def run_master(args):
    """Time 1F1B pipeline training of ResNet-50 over the ``build_cifar`` train loader.

    Reads ``batch_size``, ``chunk``, ``device``, ``world_size`` and
    ``num_microbatches`` from ``args``, assembles a two-partition pipeline,
    runs forward+backward on every training batch and prints the total and
    per-batch wall-clock time.
    """
    batch_size = args.batch_size
    chunk = args.chunk
    device = args.device
    world_size = args.world_size
    num_microbatches = args.num_microbatches
    stage_num = world_size

    assert chunk == 1

    pipelinable = PipelinableContext()

    # build model partitions
    with pipelinable:
        # input : [B, 3, 32, 32]
        model = resnet50()

    pipelinable.to_layer_list([
        'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
    ])
    partitions = [pipelinable.partition(chunk, stage_num, rank) for rank in range(world_size)]

    # build dataloader; only the first loader returned is used
    data_root = os.environ.get('DATA', './data')
    train_dataloader, _ = build_cifar(batch_size, data_root, padding=4, crop=32, resize=32)
    criterion = nn.CrossEntropyLoss()

    # Rebuild the second partition as a plain nn.Sequential with a Flatten
    # module placed just before its last layer; only partitions 0 and 1 are
    # handed to the engine.
    tail_layers = list(partitions[1]._module_list)
    tail_layers.insert(len(tail_layers) - 1, Flatten())
    module_partitions = [partitions[0], nn.Sequential(*tail_layers)]

    pp_engine = OneFOneBPipelineEngine(
        module_partitions=module_partitions,
        stage_num=stage_num,
        num_microbatches=num_microbatches,
        device=device,
        chunk=chunk,
        criterion=criterion,
        checkpoint=False,
    )
    pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    # timing covers the training loop only, not the setup above
    start = time.time()
    for bx, by in tqdm(train_dataloader):
        pp_engine.forward_backward(bx, labels=by, forward_only=False)
    cost_time = time.time() - start

    print("total cost time :", cost_time)
    print("cost time per batch:", cost_time / len(train_dataloader))
@pytest.mark.skip("Test for performance, no need for CI")
def main():
    """Parse the command line, pin the pipeline layout and launch ``run_master`` via ``rpc_run``."""
    args = parse_args()
    args.chunk = 1
    # this is due to limitation of partition function
    args.world_size = 2
    rpc_run(args, run_master)


if __name__ == '__main__':
    main()

View File

@ -1,7 +1,7 @@
import torch
from torch import nn
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel

View File

@ -2,7 +2,7 @@ import torch
from torch import nn
from torch import autograd
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
@ -36,7 +36,7 @@ def run_master(args):
chunk=chunk,
checkpoint=use_checkpoint)
forward_result = engine.forward_backward(input_sample)
forward_result = engine.forward_backward(input_sample)[0]
cuda_rpc_result = []
single_result = []