[pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera (#1595)

* [pipeline/tuning] improve dispatch performance (both time and space cost)

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera
pull/1609/head
Kirigaya Kazuto 2022-09-19 11:44:18 +08:00 committed by GitHub
parent c9e8ce67b8
commit edc9e419ad
8 changed files with 614 additions and 163 deletions
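
The core of the refactor below is that PipelineEngineBase now takes a worker class (worker_type), and every schedule is expressed as a WorkerBase subclass that only has to override _get_work_item_key. A minimal sketch of that extension point, based only on the interfaces visible in this diff (the MyFillDrainWorker / MyEngine names are illustrative, not part of the PR):

    # Sketch only: how a custom schedule plugs into the refactored base classes.
    from typing import Callable, List

    import torch.nn as nn

    from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase


    class MyFillDrainWorker(WorkerBase):

        def _get_work_item_key(self) -> UniqueKey:
            # consume every forward microbatch first, then every backward one
            if self.forward_times < self.num_microbatches:
                key = UniqueKey(self.forward_times, Phase.FORWARD)
            else:
                key = UniqueKey(self.backward_times, Phase.BACKWARD)
            with self.work_list_condition_lock:
                self.work_list_condition_lock.wait_for(lambda: key in self.work_list)
            return key


    class MyEngine(PipelineEngineBase):

        def __init__(self, module_partitions: List[nn.Module], stage_num: int, num_microbatches: int,
                     device: str, criterion: Callable = None, metric: Callable = None,
                     checkpoint: bool = False) -> None:
            # the worker class is now the first argument of PipelineEngineBase
            super().__init__(MyFillDrainWorker, module_partitions, stage_num, num_microbatches, device,
                             False, 1, criterion, metric, checkpoint)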


@@ -1,8 +1,9 @@
import threading
from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
-from abc import ABC
+from abc import ABC, abstractmethod
import sys
+import os

import torch
from torch import nn
@@ -17,12 +18,10 @@ from time import time
from colorama import Back, Style

# config for debug and test
-use_color_debug = False
-use_progress = False
+use_color_debug = True

# TODO:
-# 1. replace world_size with other parameters
-# 2. adjust to args and kwargs
+# 1. adjust to args and kwargs (pytree)


def color_debug(text, prefix=' ', color='blue'):
@@ -137,24 +136,24 @@ class BackwardCache:
        setattr(self, arg_name, locals()[arg_name])


-class Worker:
+class WorkerBase(ABC):

    def __init__(self,
                 module_partition: nn.Module,
                 pp_rank: int,
                 actual_stage_num: int,
                 num_microbatches: int,
-                use_1F1B: bool,
                 device: str,
                 criterion: Callable = None,
+                metric: Callable = None,
                 checkpoint: bool = False) -> None:
        super().__init__()

        self.pp_rank = pp_rank
        self.actual_stage_num = actual_stage_num
        self.num_microbatches = num_microbatches
        self.checkpoint = checkpoint
        self.device = device
-        self.use_1F1B = use_1F1B
        self._initialize_outstanding_range()

        # variable and const for context managment
@@ -172,23 +171,14 @@ class Worker:
        # module partitions
        self.module_partition = module_partition.to(device)

+        if criterion:
+            assert callable(criterion)
        self.criterion = criterion
+        self.metric = metric

-        # container to maintain loop
-        self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
-        self.microbatch_id_to_labels: Dict[int, Any] = dict()
-        self.work_list: Dict[UniqueKey, WorkItem] = dict()
-        self.output_list: Dict[UniqueKey, WorkItem] = dict()
+        # context to maintain loop
+        self._initialize_context_container()

        # lock for the list
-        self.work_list_condition_lock = threading.Condition(threading.Lock())
-        self.output_list_condition_lock = threading.Condition(threading.Lock())
-        self.label_lock = threading.Condition(threading.Lock())
-
-        self.step_lock = threading.Lock()
-        self.step_lock.acquire()
+        self._initialize_lock()

        # main loop
        self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
@@ -199,13 +189,23 @@ class Worker:

    def _initialize_outstanding_range(self):
        outstanding_range = None
-        if self.use_1F1B:
-            if self.pp_rank == self.actual_stage_num - 1:
-                outstanding_range = (0, 1)
-            else:
-                outstanding_range = (self.actual_stage_num, self.actual_stage_num)
+        if self.pp_rank == self.actual_stage_num - 1:
+            outstanding_range = (0, 1)
+        else:
+            outstanding_range = (self.actual_stage_num, self.actual_stage_num)
        self.outstanding_range = outstanding_range

+    def _initialize_context_container(self):
+        self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
+        self.microbatch_id_to_labels: Dict[int, Any] = dict()
+        self.work_list: Dict[UniqueKey, WorkItem] = dict()
+        self.output_list: Dict[UniqueKey, WorkItem] = dict()
+
+    def _initialize_lock(self):
+        self.work_list_condition_lock = threading.Condition(threading.Lock())
+        self.output_list_condition_lock = threading.Condition(threading.Lock())
+        self.label_lock = threading.Condition(threading.Lock())
+
    def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
        assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
        assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
@@ -241,12 +241,15 @@ class Worker:
                                 forward_only)
        with self.work_list_condition_lock:
            self.work_list[key] = work_item
-            color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
+            color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
+                        'magenta')
            self.work_list_condition_lock.notify_all()

    # just for last pp_rank
    def set_labels(self, microbatch_id: int, microlabels: Any):
-        self.microbatch_id_to_labels[microbatch_id] = microlabels
+        with self.label_lock:
+            self.microbatch_id_to_labels[microbatch_id] = microlabels
+            self.label_lock.notify_all()

    # just for last pp_rank
    def _begin_backward(self, microbatch_id: int):
@@ -354,51 +357,17 @@ class Worker:
            if next_rank <= self.actual_stage_num - 1:
                self.consumer_stage_ids.append(next_rank)

+    @abstractmethod
    def _get_work_item_key(self) -> UniqueKey:
-        # execute backward first (if backward phase in work_list)
-        pp_rank = self.pp_rank
-        actual_stage_num = self.actual_stage_num
-        num_microbatches = self.num_microbatches
-        is_last_stage = pp_rank == actual_stage_num - 1
-
-        if self.outstanding_range:
-            if self.outstanding <= self.outstanding_range[0]:
-                target_phase = Phase.FORWARD
-                target_microbatch_id = self.forward_times
-            elif self.outstanding >= self.outstanding_range[1]:
-                target_phase = Phase.BACKWARD
-                target_microbatch_id = self.backward_times
-            else:
-                raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
-
-            target_key = UniqueKey(target_microbatch_id, target_phase)
-
-            # change outstanding_range at:
-            # 1. forward times reach actual_stage_num, this is the end of continuous forward
-            # 2. forward times reach num_microbatches, this is the end of 1F1B mode
-            if not is_last_stage and \
-                target_key.phase == Phase.FORWARD:
-                if target_key.microbatch_id == actual_stage_num - 1:
-                    outstanding_min = actual_stage_num - pp_rank - 1
-                    outstanding_max = actual_stage_num - pp_rank
-                    self.outstanding_range = (outstanding_min, outstanding_max)
-                elif target_key.microbatch_id == num_microbatches - 1:
-                    self.outstanding_range = (0, 0)
-        else:
-            if self.forward_times < num_microbatches:
-                target_phase = Phase.FORWARD
-                target_microbatch_id = self.forward_times
-            else:
-                target_phase = Phase.BACKWARD
-                target_microbatch_id = self.backward_times
-            target_key = UniqueKey(target_microbatch_id, target_phase)
-
-        with self.work_list_condition_lock:
-            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
-
-        return target_key
+        """
+        this method control the order of the microbatch to consume
+        """
+
+    def is_first_stage(self):
+        return self.pp_rank == 0
+
+    def is_last_stage(self):
+        return self.pp_rank == self.actual_stage_num - 1

    def _consume_work_item_by_phase(self, work_item: WorkItem):
        phase = work_item.phase
@@ -408,9 +377,8 @@ class Worker:
        forward_only = work_item.forward_only

        consume_result = None
-        # TODO : use process manager to acquire rank info later
-        is_first_stage = (self.pp_rank == 0)
-        is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
+        is_first_stage = self.is_first_stage()
+        is_last_stage = self.is_last_stage()

        # if self.pp_rank == 0:
        #     print(
@@ -433,6 +401,21 @@ class Worker:
        if forward_only:
            with torch.no_grad():
                consume_result = self.module_partition(*args, **kwargs)

+            # TODO : integrate output list
+            if is_last_stage and self.criterion:
+                with self.label_lock:
+                    self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
+                labels = self.microbatch_id_to_labels.pop(microbatch_id)
+                loss: torch.Tensor = self.criterion(consume_result, labels)
+                if self.metric is not None:
+                    metric_result = self.metric(consume_result, labels)
+                    if isinstance(metric_result, torch.Tensor):
+                        metric_result = metric_result.item()
+                else:
+                    metric_result = None
+                consume_result = [loss.item(), metric_result]
+
            stage_outputs = None
            stage_inputs = None
            use_checkpoint = None
@@ -444,10 +427,21 @@
                    use_checkpoint = True
                else:
                    consume_result = self.module_partition(*args, **kwargs)

+            # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
+
            if is_last_stage and self.criterion:
+                with self.label_lock:
+                    self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
                labels = self.microbatch_id_to_labels.pop(microbatch_id)
                loss: torch.Tensor = self.criterion(consume_result, labels)
-                consume_result = loss.item()
+                if self.metric is not None:
+                    metric_result = self.metric(consume_result, labels)
+                    if isinstance(metric_result, torch.Tensor):
+                        metric_result = metric_result.item()
+                else:
+                    metric_result = None
+                consume_result = [loss.item(), metric_result]
            else:
                loss = consume_result
@@ -486,6 +480,7 @@
            if use_checkpoint:
                stage_outputs = [self.module_partition(*stage_inputs)]
            # overlap recompute and future.wait
            grad_tensors = get_real_args(args)
@@ -513,11 +508,17 @@
                grad_sum += p.grad.sum()
        return grad_sum

-    def _is_first_step(self, work_item) -> bool:
+    def _is_first_step(self, work_item: WorkItem) -> bool:
        return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0

-    def _is_last_step(self, work_item) -> bool:
-        return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1
+    def _is_last_step(self, work_item: WorkItem) -> bool:
+        if work_item.forward_only:
+            last_phase = Phase.FORWARD
+        else:
+            last_phase = Phase.BACKWARD
+        is_last_phase = work_item.phase == last_phase
+        is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
+        return is_last_phase and is_last_microbatch

    # do the main loop to consume ready_list
    def _work_loop(self):
@@ -551,7 +552,7 @@
            # if is last step in one batch reset context and do step
            if self._is_last_step(work_item):
-                if hasattr(self, 'optimizer'):
+                if hasattr(self, 'optimizer') and not work_item.forward_only:
                    self.step()
                self.forward_times = 0
                self.backward_times = 0
@@ -560,22 +561,22 @@

    def initialize_optimizer(self, optimizer_class: type, **kwargs):
        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+        self.step_lock = threading.Lock()
+        self.step_lock.acquire()

    def wait_for_step(self):
        self.step_lock.acquire()

    def step(self):
+        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
        self.optimizer.step()
+        # print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
        self.optimizer.zero_grad()
        self.step_lock.release()


class PipelineEngineBase(ABC, nn.Module):

    def __init__(self,
+                 worker_type,
                 module_partitions,
                 stage_num,
                 num_microbatches,
@@ -583,17 +584,19 @@ class PipelineEngineBase(ABC, nn.Module):
                 use_1F1B=False,
                 chunk: int = 1,
                 criterion: Callable = None,
+                metric: Callable = None,
                 checkpoint: bool = False) -> None:
        super().__init__()
+        self.worker_type = worker_type
        self.module_partitions: List[nn.Module] = module_partitions
        self.chunk = chunk
        self.criterion = criterion
+        self.metric = metric
        self.num_microbatches = num_microbatches
        self.device = device
        self.use_1F1B = use_1F1B
        self.stage_num = stage_num
        self.checkpoint = checkpoint
-        self.use_interleave = chunk > 1

        self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@@ -601,26 +604,24 @@ class PipelineEngineBase(ABC, nn.Module):

        self._check_argument()
        self._create_pp_rank_to_rpc_worker_id()
+        self._create_pp_rank_to_module_partition_id()
        self._init_worker()

-    def _check_argument(self):
+    def _check_argument(self) -> None:
        self.virtual_stage_num = self.stage_num * self.chunk

        assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
        assert self.virtual_stage_num == len(
            self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
-        if self.use_interleave:
-            assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"

-    def _get_actual_stage_num(self):
+    def _get_actual_stage_num(self) -> int:
        return self.stage_num if self.chunk == 1 else self.virtual_stage_num

-    def _create_pp_rank_to_rpc_worker_id(self):
+    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        """create a map from model partition to stage_id, which is useful when use_interleave is True.
        e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
        stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
        of partitions will be moved to device 0 and the others to device 1
        """
        stage_num = self.stage_num
        actual_stage_num = self._get_actual_stage_num()
@@ -628,28 +629,39 @@ class PipelineEngineBase(ABC, nn.Module):
        for pp_rank in range(actual_stage_num):
            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num

-    def _init_worker(self):
+    def _create_pp_rank_to_module_partition_id(self) -> None:
+        """By default(both fill drain and 1F1B), length of model partitions equal to
+        actual_stage_num, so allocate model partition to corresponding stage
+        """
+        actual_stage_num = self._get_actual_stage_num()
+        self.pp_rank_to_module_partition_id = [0] * actual_stage_num
+        for pp_rank in range(actual_stage_num):
+            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
+
+    def _init_worker(self) -> None:
        actual_stage_num = self._get_actual_stage_num()

-        use_1F1B = self.use_1F1B
+        worker_type = self.worker_type
        checkpoint = self.checkpoint
        num_microbatches = self.num_microbatches
        device = self.device
        criterion = self.criterion
+        metric = self.metric

-        for pp_rank in range(actual_stage_num):
-            module_partition = self.module_partitions[pp_rank]
+        for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
+            module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
            rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
            if device[:4] == 'cuda':
                device = f'cuda:{rpc_worker_id}'
+            module_partition = self.module_partitions[module_partition_id]
            self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
-                                                              Worker,
+                                                              worker_type,
                                                              args=(module_partition, pp_rank, actual_stage_num,
-                                                                    num_microbatches, use_1F1B, device, criterion,
+                                                                    num_microbatches, device, criterion, metric,
                                                                    checkpoint))

        # let each worker know global worker rref (include itself)
-        for pp_rank in range(actual_stage_num):
+        for pp_rank in self.pp_rank_to_worker_rref:
            self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)

    def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
@@ -670,65 +682,110 @@ class PipelineEngineBase(ABC, nn.Module):
                grads[stage_id].append(grad)
        return grads

+    def get_input_pp_ranks(self) -> List[int]:
+        return [0]
+
+    def get_output_pp_ranks(self) -> List[int]:
+        return [self._get_actual_stage_num() - 1]
+
+    def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
+                            output_pp_ranks: List[int], ret_future):
+        actual_stage_num = self._get_actual_stage_num()
+        use_1F1B = self.use_1F1B
+        if microbatch_id >= actual_stage_num:
+            if forward_only or not use_1F1B:
+                for pp_rank in output_pp_ranks:
+                    ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
+            else:
+                key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
+                for pp_rank in input_pp_ranks:
+                    worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                    worker_rref.rpc_sync().get_output_by_key(key)
+
+    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
+        num_microbatches = self.num_microbatches
+        return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
+
+    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
+        for pp_rank in input_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            # TODO : add relationship between input_pp_ranks and parts of microbatch
+            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+
+    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
+        for pp_rank in output_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            # TODO : add relationship between output_pp_ranks and parts of microlabels
+            worker_rref.remote().set_labels(microbatch_id, microlabels)
+
+    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
+        key = UniqueKey(microbatch_id, Phase.FORWARD)
+        for pp_rank in output_pp_ranks:
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            ret_future[pp_rank][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
+
+    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
+        if not forward_only:
+            for pp_rank in input_pp_ranks:
+                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+                key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
+                worker_rref.rpc_sync().get_output_by_key(key)
+
+    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
+        forward_result = []
+        for pp_rank in output_pp_ranks:
+            worker_forward_result = [None] * self.num_microbatches
+            for microbatch_id in range(self.num_microbatches):
+                ret = ret_future[pp_rank][microbatch_id].wait()
+                worker_forward_result[microbatch_id] = ret
+            worker_forward_result = list(zip(*worker_forward_result))
+            forward_result.extend(worker_forward_result)
+        return forward_result
+
    def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
        if labels is not None:
            assert len(batch) == len(labels)
+            if not forward_only:
+                assert hasattr(self, 'optimizer_class')

        num_microbatches = self.num_microbatches
        microbatch_size = len(batch) // num_microbatches
-        actual_stage_num = self._get_actual_stage_num()

-        first_worker_rref = self.pp_rank_to_worker_rref[0]
-        last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
+        # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
+        input_pp_ranks = self.get_input_pp_ranks()
+        output_pp_ranks = self.get_output_pp_ranks()

-        microbatch_iter = range(num_microbatches)
-        if use_progress:
-            microbatch_iter = tqdm(microbatch_iter)
-        ret_future: List[Future] = [None] * num_microbatches
-        for microbatch_id in microbatch_iter:
+        # a cache to collect data and control flow
+        ret_future = self._create_ret_future(output_pp_ranks)
+
+        for microbatch_id in range(num_microbatches):
            # control data input speed
            # to prevent exceed of wait limitations
-            if microbatch_id >= actual_stage_num:
-                if forward_only or not self.use_1F1B:
-                    ret_future[microbatch_id - actual_stage_num].wait()
-                else:
-                    key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
-                    first_worker_rref.rpc_sync().get_output_by_key(key)
+            self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)

            # set input
            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
            microbatch = microbatch.cuda()
-            first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
+            self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)

            # set labels
-            if not forward_only and labels is not None:
+            if labels is not None:
                microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
                microlabels = microlabels.cuda()
-                last_worker_rref.remote().set_labels(microbatch_id, microlabels)
+                self._set_labels(output_pp_ranks, microbatch_id, microlabels)

-            key = UniqueKey(microbatch_id, Phase.FORWARD)
-            ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
+            # get data asynchronously
+            self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future)

-        # wait for last backward in rank0
-        if not forward_only:
-            key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
-            first_worker_rref.rpc_sync().get_output_by_key(key)
+        # wait for first rank to ensure all backwards are done
+        self._ensure_backward(forward_only, input_pp_ranks)

        # collect forward result
-        # TODO : all the node to output
-        forward_result = None
-        for microbatch_id in range(self.num_microbatches):
-            key = UniqueKey(microbatch_id, Phase.FORWARD)
-            ret = ret_future[microbatch_id].wait()
-            if forward_result is None:
-                forward_result = [[]] * len(ret)
-            for i in range(len(forward_result)):
-                forward_result[i].append(ret[i])
-
-        if hasattr(self, 'optimizer_class'):
+        forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
+
+        if not forward_only and labels is not None:
            # wait for all step
-            # TODO : more elegant ?
            for pp_rank in self.pp_rank_to_worker_rref:
                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
                worker_rref.rpc_sync().wait_for_step()
@@ -751,31 +808,3 @@
        for fut in self.step_futs:
            fut.wait()
-
-
-class FillDrainPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 module_partitions: List[nn.Module],
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 chunk: int = 1,
-                 criterion: Callable = None,
-                 checkpoint: bool = False) -> None:
-        use_1F1B = False
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
-
-
-class OneFOneBPipelineEngine(PipelineEngineBase):
-
-    def __init__(self,
-                 module_partitions: List[nn.Module],
-                 stage_num: int,
-                 num_microbatches: int,
-                 device: str,
-                 chunk: int = 1,
-                 criterion: Callable = None,
-                 checkpoint: bool = False) -> None:
-        use_1F1B = True
-        super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
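
As a quick illustration of the mapping helpers above (_create_pp_rank_to_rpc_worker_id and the new _create_pp_rank_to_module_partition_id), here is the default layout they produce; this snippet is illustrative only and not part of the diff:

    # Default (fill-drain / 1F1B) layout for stage_num = 2, chunk = 2.
    stage_num, chunk = 2, 2
    actual_stage_num = stage_num * chunk
    pp_rank_to_rpc_worker_id = [pp_rank % stage_num for pp_rank in range(actual_stage_num)]
    pp_rank_to_module_partition_id = list(range(actual_stage_num))

    print(pp_rank_to_rpc_worker_id)          # [0, 1, 0, 1] -> partitions 0 and 2 on device 0, 1 and 3 on device 1
    print(pp_rank_to_module_partition_id)    # [0, 1, 2, 3] -> pp_rank i simply owns partition i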


@@ -0,0 +1,277 @@
from typing import List, Callable, Dict
import torch.nn as nn
from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
# Implementation of different Pipeline schedule
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the class for use
class FillDrainWorker(WorkerBase):
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
num_microbatches = self.num_microbatches
if self.forward_times < num_microbatches:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
target_key = UniqueKey(target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
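
For orientation, an illustrative trace (not part of the diff) of the order in which FillDrainWorker consumes work items, assuming num_microbatches = 4:

    # Every stage drains all forward microbatches first, then all backward ones:
    order = [(Phase.FORWARD, i) for i in range(4)] + [(Phase.BACKWARD, i) for i in range(4)]
    # [(FORWARD, 0), (FORWARD, 1), (FORWARD, 2), (FORWARD, 3),
    #  (BACKWARD, 0), (BACKWARD, 1), (BACKWARD, 2), (BACKWARD, 3)]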
class FillDrainPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
use_1F1B = False
super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
criterion, metric, checkpoint)
class OneFOneBWorker(WorkerBase):
def _get_work_item_key(self) -> UniqueKey:
# execute backward first (if backward phase in work_list)
pp_rank = self.pp_rank
actual_stage_num = self.actual_stage_num
num_microbatches = self.num_microbatches
is_last_stage = pp_rank == actual_stage_num - 1
if self.outstanding <= self.outstanding_range[0]:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
elif self.outstanding >= self.outstanding_range[1]:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else:
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
target_key = UniqueKey(target_microbatch_id, target_phase)
# change outstanding_range at:
# 1. forward times reach actual_stage_num, this is the end of continuous forward
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
if not is_last_stage and \
target_key.phase == Phase.FORWARD:
if target_key.microbatch_id == actual_stage_num - 1:
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
elif target_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
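
The outstanding_range bookkeeping is the subtle part of this worker; the following walk-through is illustrative only (not part of the diff), for actual_stage_num = 4 and num_microbatches = 8:

    # pp_rank 3 (last stage): outstanding_range starts at (0, 1) -> strict one-forward-one-backward from the start.
    # pp_rank 1: outstanding_range starts at (4, 4) -> keeps forwarding during warm-up;
    #   after forward microbatch 3 (== actual_stage_num - 1) it becomes (2, 3), the 1F1B steady state,
    #   since outstanding_min = 4 - 1 - 1 = 2 and outstanding_max = 4 - 1 = 3;
    #   after forward microbatch 7 (== num_microbatches - 1) it becomes (0, 0), i.e. drain the remaining backwards.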
class OneFOneBPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions: List[nn.Module],
stage_num: int,
num_microbatches: int,
device: str,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
use_1F1B = True
super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
criterion, metric, checkpoint)
class ChimeraWorker(WorkerBase):
def _get_producer_consumer(self) -> None:
rank = self.pp_rank
min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
max_pp_rank = min_pp_rank + self.actual_stage_num - 1
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
# should be aranged in order, the order of the input of current forward
self.producer_stage_ids = []
self.consumer_stage_ids = []
# Just for demo
prev_rank = rank - 1
next_rank = rank + 1
if prev_rank >= min_pp_rank:
self.producer_stage_ids.append(prev_rank)
if next_rank <= max_pp_rank:
self.consumer_stage_ids.append(next_rank)
def _get_work_item_key(self) -> UniqueKey:
pp_rank = self.pp_rank
stage_num = self.actual_stage_num
real_microbatch_num = self.num_microbatches // 2
if self.forward_times < real_microbatch_num:
if (pp_rank + 1) % stage_num == 0: # last rank
forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
if forward_blocks > self.backward_times:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
else:
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else: # others
target_phase = Phase.FORWARD
target_microbatch_id = self.forward_times
else:
target_phase = Phase.BACKWARD
target_microbatch_id = self.backward_times
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
real_target_microbatch_id = target_microbatch_id * 2
if pp_rank >= stage_num:
real_target_microbatch_id += 1
target_key = UniqueKey(real_target_microbatch_id, target_phase)
with self.work_list_condition_lock:
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
return target_key
def is_first_stage(self):
return (self.pp_rank % self.actual_stage_num) == 0
def is_last_stage(self):
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
class ChimeraPipelineEngine(PipelineEngineBase):
def __init__(self,
module_partitions,
stage_num,
num_microbatches,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
assert num_microbatches % stage_num == 0, \
"In Chimera, num_microbatches must be the multiply of stage_num!"
use_1F1B = False
chunk = 1
super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
criterion, metric, checkpoint)
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
pass
def _create_pp_rank_to_rpc_worker_id(self) -> None:
stage_num = self.stage_num
self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
for pp_rank in range(stage_num):
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1
def _create_pp_rank_to_module_partition_id(self) -> None:
stage_num = self.stage_num
self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
for pp_rank in range(stage_num):
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
num_microbatches = self.num_microbatches
stage_num = self.stage_num
up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
# merge up and down
return {**up_ret_future, **down_ret_future}
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
# offset is 0 for all the ranks in up pipeline
# offset is stage_num for all the ranks in down pipeline
offset = (microbatch_id % 2) * self.stage_num
for pp_rank in input_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
# offset is 0 for all the ranks in up pipeline
# offset is stage_num for all the ranks in down pipeline
offset = (microbatch_id % 2) * self.stage_num
for pp_rank in output_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
worker_rref.remote().set_labels(microbatch_id, microlabels)
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
key = UniqueKey(microbatch_id, Phase.FORWARD)
offset = (microbatch_id % 2) * self.stage_num
for pp_rank in output_pp_ranks:
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
stage_num = self.stage_num
num_microbatches = self.num_microbatches
if not forward_only:
for pp_rank in input_pp_ranks:
up_last_microbatch_id = num_microbatches - 2
down_last_microbatch_id = num_microbatches - 1
up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
up_worker_rref.rpc_sync().get_output_by_key(up_key)
down_worker_rref.rpc_sync().get_output_by_key(down_key)
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]):
"""Logic of collection of forward in Chimera.
Currently, only one input one output model is supported
"""
stage_num = self.stage_num
forward_result = []
for pp_rank in output_pp_ranks:
worker_forward_result = [None] * self.num_microbatches
for microbatch_id in range(self.num_microbatches):
offset = (microbatch_id % 2) * stage_num
ret = ret_future[pp_rank + offset][microbatch_id].wait()
worker_forward_result[microbatch_id] = ret
worker_forward_result = list(zip(*worker_forward_result))
forward_result.extend(worker_forward_result)
return forward_result
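
To make the up/down pipeline bookkeeping above concrete, a small illustrative sketch (not part of the diff) of the rank, device and microbatch assignment for stage_num = 2:

    # Chimera layout for stage_num = 2, i.e. 4 pipeline ranks in total.
    stage_num = 2
    pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
    pp_rank_to_module_partition_id = [0] * (stage_num * 2)
    for pp_rank in range(stage_num):
        pp_rank_to_rpc_worker_id[pp_rank] = pp_rank                              # up pipeline: rank i -> device i
        pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1  # down pipeline: devices reversed
        pp_rank_to_module_partition_id[pp_rank] = pp_rank                        # both pipelines hold the same partitions
        pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank

    print(pp_rank_to_rpc_worker_id)          # [0, 1, 1, 0]
    print(pp_rank_to_module_partition_id)    # [0, 1, 0, 1]
    # Even microbatches (0, 2, ...) go to the up pipeline (offset 0),
    # odd microbatches (1, 3, ...) to the down pipeline (offset = stage_num), matching _set_input above.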

data/cifar-10-python.tar.gz (new binary file, not shown)


@@ -0,0 +1,43 @@
import torch
from torch import nn
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
def run_master(args):
torch.manual_seed(100)
epoch = args.epoch
device = args.device
stage_num = 4
chunk = 1
num_microbatches = 4
actual_stage_num = 4
use_checkpoint = False
sample_num = 1024
feat_num = 10
h = 10
batch_size = 1024
assert sample_num % batch_size == 0
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
engine = ChimeraPipelineEngine(module_partitions=module_partitions,
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
checkpoint=use_checkpoint)
input_sample = torch.randn((sample_num, feat_num), device=device)
for _ in range(epoch):
_ = engine.forward_backward(input_sample, forward_only=False)
if __name__ == "__main__":
args = parse_args()
args.world_size = 4
args.num_microbatches = 4
rpc_run(args, run_master)


@@ -3,7 +3,7 @@ from torch import nn
from torch import autograd
from torch.optim import SGD, Adam, RMSprop, Optimizer

-from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel


@@ -0,0 +1,102 @@
import os
from typing import Callable, List, Optional, Type, Union
import time
import pytest
import torch
import torch.nn as nn
from titans.dataloader.cifar10 import build_cifar
from torchvision.models import resnet50
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
from tqdm import tqdm
from rpc_test_utils import rpc_run, parse_args
import colossalai
import colossalai.nn as col_nn
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.context import ParallelMode
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
def flatten(x):
return torch.flatten(x, 1)
class Flatten(nn.Module):
def forward(self, x):
return torch.flatten(x, start_dim=1)
def run_master(args):
batch_size = args.batch_size
chunk = args.chunk
device = args.device
world_size = args.world_size
stage_num = world_size
num_microbatches = args.num_microbatches
assert chunk == 1
pipelinable = PipelinableContext()
# build model partitions
with pipelinable:
# input : [B, 3, 32, 32]
model = resnet50()
exec_seq = [
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
]
pipelinable.to_layer_list(exec_seq)
module_partitions: List[PipelinableModel] = [
pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size)
]
# build dataloader
root = os.environ.get('DATA', './data')
train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
criterion = nn.CrossEntropyLoss()
partition_1 = module_partitions[0]
partition_2 = []
for model in module_partitions[1]._module_list:
partition_2.append(model)
partition_2.insert(len(partition_2) - 1, Flatten())
partition_2 = nn.Sequential(*partition_2)
module_partitions = [partition_1, partition_2]
pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=chunk,
criterion=criterion,
checkpoint=False)
pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
s = time.time()
for bx, by in tqdm(train_dataloader):
pp_engine.forward_backward(bx, labels=by, forward_only=False)
cost_time = time.time() - s
print("total cost time :", cost_time)
print("cost time per batch:", cost_time / len(train_dataloader))
@pytest.mark.skip("Test for performance, no need for CI")
def main():
args = parse_args()
# this is due to limitation of partition function
args.world_size = 2
args.chunk = 1
rpc_run(args, run_master)
if __name__ == '__main__':
main()


@@ -1,7 +1,7 @@
import torch
from torch import nn

-from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from rpc_test_utils import rpc_run, parse_args, RpcTestModel


@@ -2,7 +2,7 @@ import torch
from torch import nn
from torch import autograd

-from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
@@ -36,7 +36,7 @@ def run_master(args):
                                     chunk=chunk,
                                     checkpoint=use_checkpoint)

-    forward_result = engine.forward_backward(input_sample)
+    forward_result = engine.forward_backward(input_sample)[0]

    cuda_rpc_result = []
    single_result = []