mirror of https://github.com/hpcaitech/ColossalAI
343 lines
14 KiB
Python
343 lines
14 KiB
Python
from typing import List, Callable, Dict
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.futures import Future
|
|
from torch._C._distributed_rpc import PyRRef
|
|
|
|
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase, WorkItem
|
|
from colossalai.pipeline.pipeline_process_group import ppg
|
|
|
|
# Implementations of the different pipeline schedules.
# <strategy>Worker defines the worker that runs each stage.
# <strategy>PipelineEngine is the user-facing class for that schedule.
|
|
|
|
|
|
class FillDrainWorker(WorkerBase):
    """Stage worker for the fill-drain schedule.

    Every forward pass runs before any backward pass. The worker consumes the
    forward work items of microbatches ``0 .. num_microbatches - 1`` in order.
    Once all of them have been issued, it consumes the backward work items in
    the same order.
    """

    def _get_work_item_key(self) -> UniqueKey:
        """Choose the next work item key and wait until it appears in ``work_list``.

        Returns:
            UniqueKey: ``UniqueKey(forward_times, Phase.FORWARD)`` while forward
            passes remain, otherwise ``UniqueKey(backward_times, Phase.BACKWARD)``.
        """
        forwards_pending = self.forward_times < self.num_microbatches
        if forwards_pending:
            key = UniqueKey(self.forward_times, Phase.FORWARD)
        else:
            key = UniqueKey(self.backward_times, Phase.BACKWARD)

        # Block on the condition until this exact key has been put into the work list.
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: key in self.work_list)

        return key
|
|
|
|
|
|
class FillDrainPipelineEngine(PipelineEngineBase):
    """Pipeline engine that runs the fill-drain schedule on :class:`FillDrainWorker` stages."""

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:
        """Validate the interleaving setup and initialise the base engine.

        All arguments are forwarded positionally to ``PipelineEngineBase.__init__``.
        They are passed together with ``FillDrainWorker`` as the worker class and
        with 1F1B disabled.

        Raises:
            AssertionError: when ``chunk > 1`` and ``num_microbatches`` is not a
                multiple of ``stage_num``.
        """
        # Only interleaved runs (chunk > 1) constrain the number of microbatches.
        needs_even_split = chunk > 1
        assert not needs_even_split or num_microbatches % stage_num == 0, \
            "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"

        worker_type = FillDrainWorker
        enable_1f1b = False
        super().__init__(worker_type, partition_fn, stage_num, num_microbatches, device, enable_1f1b, chunk,
                         criterion, metric, checkpoint, data_process_func)
|
|
|
|
|
|
class OneFOneBWorker(WorkerBase):
    """Stage worker for the 1F1B (one-forward-one-backward) schedule.

    The choice between a forward and a backward work item is made by comparing
    ``self.outstanding`` against ``self.outstanding_range``. Both are maintained
    by ``WorkerBase``; ``outstanding`` is presumably the number of microbatches
    forwarded but not yet backwarded (TODO confirm in WorkerBase).
    """

    def _get_work_item_key(self) -> UniqueKey:
        """Choose the next work item key and wait until it appears in ``work_list``.

        Selection rules:

        * ``outstanding <= outstanding_range[0]``: forward of microbatch ``forward_times``.
        * ``outstanding >= outstanding_range[1]``: backward of microbatch ``backward_times``.
        * Anything strictly in between is rejected.

        On every stage except the last, choosing a forward may also update
        ``outstanding_range``:

        * After forward ``actual_stage_num - 1`` (end of the continuous warm-up
          forwards), it becomes
          ``(actual_stage_num - pp_rank - 1, actual_stage_num - pp_rank)``.
        * After forward ``num_microbatches - 1`` (the last forward), it becomes
          ``(0, 0)``, so only backwards remain.

        Returns:
            UniqueKey: the selected ``(microbatch_id, phase)`` key.

        Raises:
            ValueError: if ``outstanding`` lies strictly inside ``outstanding_range``.
        """
        pp_rank = self.pp_rank
        actual_stage_num = self.actual_stage_num
        num_microbatches = self.num_microbatches
        is_last_stage = pp_rank == actual_stage_num - 1

        if self.outstanding <= self.outstanding_range[0]:
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times
        elif self.outstanding >= self.outstanding_range[1]:
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:
            raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")

        target_key = UniqueKey(target_microbatch_id, target_phase)

        if not is_last_stage and target_key.phase == Phase.FORWARD:
            # 1. Forward times reach actual_stage_num: end of the continuous forwards,
            #    switch to steady 1F1B.
            if target_key.microbatch_id == actual_stage_num - 1:
                outstanding_min = actual_stage_num - pp_rank - 1
                outstanding_max = actual_stage_num - pp_rank
                self.outstanding_range = (outstanding_min, outstanding_max)
            # 2. Forward times reach num_microbatches: end of 1F1B, only backwards remain.
            # Deliberately an independent `if` (not `elif`). When
            # num_microbatches == actual_stage_num, both conditions hit the same
            # microbatch and the drain range must win. Otherwise the steady range
            # stays installed, and once backwards lower `outstanding` the worker
            # could target forward microbatch `num_microbatches`, which never
            # exists, and wait forever.
            if target_key.microbatch_id == num_microbatches - 1:
                self.outstanding_range = (0, 0)

        # Block on the condition until this exact key has been put into the work list.
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)

        return target_key
|
|
|
|
|
|
class OneFOneBPipelineEngine(PipelineEngineBase):
    """Pipeline engine that runs the 1F1B schedule on :class:`OneFOneBWorker` stages."""

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 chunk: int = 1,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:
        """Validate the interleaving setup and initialise the base engine.

        All arguments are forwarded positionally to ``PipelineEngineBase.__init__``.
        They are passed together with ``OneFOneBWorker`` as the worker class and
        with 1F1B enabled.

        Raises:
            AssertionError: when ``chunk > 1`` and ``num_microbatches`` is not a
                multiple of ``stage_num``.
        """
        # Only interleaved runs (chunk > 1) constrain the number of microbatches.
        # NOTE: no check that num_microbatches > stage_num * chunk is enforced here.
        needs_even_split = chunk > 1
        assert not needs_even_split or num_microbatches % stage_num == 0, \
            "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"

        worker_type = OneFOneBWorker
        enable_1f1b = True
        super().__init__(worker_type, partition_fn, stage_num, num_microbatches, device, enable_1f1b, chunk,
                         criterion, metric, checkpoint, data_process_func)
|
|
|
|
|
|
class ChimeraWorker(WorkerBase):
    """Stage worker for the Chimera schedule, which runs an "up" and a "down" pipeline.

    pp_ranks ``0 .. actual_stage_num - 1`` belong to the up pipeline and
    ``actual_stage_num .. 2 * actual_stage_num - 1`` to the down pipeline.
    The up pipeline consumes even microbatch ids; the down pipeline consumes
    odd ones.
    """

    def _get_producer_consumer(self) -> None:
        """Set ``producer_stage_ids`` and ``consumer_stage_ids`` for this rank.

        Stages are chained linearly inside the rank's own pipeline. The producer
        is ``pp_rank - 1`` and the consumer is ``pp_rank + 1``. Each is kept only
        if it falls in the same block of ``actual_stage_num`` ranks as this rank.
        """
        rank = self.pp_rank
        # Bounds of the pipeline (up or down) that this rank belongs to.
        min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
        max_pp_rank = min_pp_rank + self.actual_stage_num - 1

        assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
        assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"

        # Should be arranged in order: the order of the inputs of the current forward.
        self.producer_stage_ids = []
        self.consumer_stage_ids = []

        # Just for demo: a simple linear chain of neighbouring ranks.
        prev_rank = rank - 1
        next_rank = rank + 1
        if prev_rank >= min_pp_rank:
            self.producer_stage_ids.append(prev_rank)
        if next_rank <= max_pp_rank:
            self.consumer_stage_ids.append(next_rank)

    def _get_work_item_key(self) -> UniqueKey:
        """Choose the next work item key and wait until it appears in ``work_list``.

        ``forward_times`` and ``backward_times`` count this worker's own
        microbatches. The chosen count ``i`` is mapped to the global microbatch
        id before the key is built:

        * ``2 * i`` on the up pipeline;
        * ``2 * i + 1`` on the down pipeline.
        """
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        # Each of the two pipelines handles half of the microbatches.
        real_microbatch_num = self.num_microbatches // 2

        forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num
        forward_block_num = self.forward_times // forward_block_size

        # Go backward in either of two cases:
        # - all of this pipeline's forwards have been issued;
        # - this is the last stage of its pipeline, and the number of completed
        #   forward blocks exceeds the number of backwards done so far.
        if self.forward_times >= real_microbatch_num or \
                ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times):
            target_phase = Phase.BACKWARD
            target_microbatch_id = self.backward_times
        else:    # others
            target_phase = Phase.FORWARD
            target_microbatch_id = self.forward_times

        # In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
        # In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
        real_target_microbatch_id = target_microbatch_id * 2
        if pp_rank >= stage_num:
            real_target_microbatch_id += 1
        target_key = UniqueKey(real_target_microbatch_id, target_phase)

        # Block on the condition until this exact key has been put into the work list.
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
        return target_key

    def _initialize_partition(self):
        """Build this worker's module partition and register its Chimera all-reduce group.

        A down-pipeline worker must share parameters with its up-pipeline
        counterpart. It therefore builds its partition the usual way and then
        loads the state dict of up rank ``pp_rank - actual_stage_num``.
        """
        pp_rank = self.pp_rank
        stage_num = self.actual_stage_num
        if pp_rank < stage_num:
            super()._initialize_partition()
        else:
            co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
            # Get the corresponding model state dict; the synchronous RPC waits for its init.
            state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
            super()._initialize_partition()
            self.module_partition.load_state_dict(state_dict)

        # Init the group for Chimera in ppg.
        ppg.get_chimera_all_reduce_group(pp_rank)

    def is_first_stage(self):
        """Return True if this rank is the first stage of its (up or down) pipeline."""
        return (self.pp_rank % self.actual_stage_num) == 0

    def is_last_stage(self):
        """Return True if this rank is the last stage of its (up or down) pipeline."""
        return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1

    def _is_last_step(self, work_item: WorkItem) -> bool:
        """Return True if ``work_item`` is the final work item of this worker's run.

        The final phase is FORWARD for forward-only runs and BACKWARD otherwise.
        The final microbatch is ``num_microbatches - 1`` on the down pipeline and
        ``num_microbatches - 2`` on the up pipeline, i.e. the last even id when
        ``num_microbatches`` is even.
        """
        if work_item.forward_only:
            last_phase = Phase.FORWARD
        else:
            last_phase = Phase.BACKWARD
        is_last_phase = work_item.phase == last_phase
        last_microbatch_id = self.num_microbatches - 1
        if self.pp_rank < self.actual_stage_num:
            last_microbatch_id -= 1
        is_last_microbatch = work_item.microbatch_id == last_microbatch_id
        return is_last_phase and is_last_microbatch

    def _get_step_order(self) -> List[int]:
        """Return the order in which the pp_ranks on this worker's device perform their step.

        The pair is this rank and ``2 * actual_stage_num - pp_rank - 1``. Device
        pairs whose smaller rank is below ``actual_stage_num // 2`` step in
        descending rank order; the others step in ascending order.
        """
        # TODO : If you want to extend it to multi head chimera, overwrite here
        stage_num = self.actual_stage_num
        pp_rank = self.pp_rank
        # pp_ranks on the same device
        local_device_pp_ranks = [pp_rank, stage_num * 2 - pp_rank - 1]
        local_device_pp_ranks.sort(reverse=min(local_device_pp_ranks) < stage_num // 2)
        return local_device_pp_ranks

    def _hook_before_step(self):
        """All-reduce this rank's parameter gradients over its Chimera group.

        The two pp_ranks returned by ``_get_step_order()`` are ordered through
        ``ppg.chimera_step_lock``:

        * the rank at index 1 acquires the lock before its all-reduce;
        * the rank at index 0 releases it after its own all-reduce.

        NOTE(review): this ordering presumes the lock starts out held; confirm in ppg.
        """
        pp_rank = self.pp_rank

        orders = self._get_step_order()
        step_index = orders.index(pp_rank)

        all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
        grads = self.get_parameter_gradients()

        # If the current pp_rank is not the first to do step, wait for its
        # previous pp_rank to finish.
        if step_index == 1:
            ppg.chimera_step_lock.acquire()

        dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)

        if step_index == 0:
            ppg.chimera_step_lock.release()
|
|
|
|
|
|
class ChimeraPipelineEngine(PipelineEngineBase):
    """Pipeline engine for the Chimera schedule (two pipelines running in opposite directions).

    pp_ranks ``0 .. stage_num - 1`` form the up pipeline and
    ``stage_num .. 2 * stage_num - 1`` form the down pipeline. Even microbatch
    ids are routed to the up pipeline and odd ids to the down pipeline.
    """

    def __init__(self,
                 partition_fn: Callable,
                 stage_num: int,
                 num_microbatches: int,
                 device: str,
                 criterion: Callable = None,
                 metric: Callable = None,
                 checkpoint: bool = False,
                 data_process_func: Callable = None) -> None:
        """Validate the microbatch count and initialise the base engine with ``ChimeraWorker``.

        Raises:
            AssertionError: if ``num_microbatches`` is not a multiple of
                ``stage_num``, or is odd.
        """
        assert num_microbatches % stage_num == 0, \
            "In Chimera, num_microbatches must be the multiply of stage_num!"
        # ChimeraWorker schedules num_microbatches // 2 microbatches per pipeline,
        # and treats num_microbatches - 2 as the up pipeline's last (even) id.
        # With an odd count, the up pipeline's last microbatch would never be
        # scheduled and the run would hang, so reject it up front.
        assert num_microbatches % 2 == 0, "In Chimera, num_microbatches must be even!"
        use_1F1B = False
        chunk = 1

        super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
                         metric, checkpoint, data_process_func)

    def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
                            input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
        """Intentionally a no-op: Chimera overrides the base hook and adds no consumption constraint."""
        pass

    def _create_pp_rank_to_rpc_worker_id(self) -> None:
        """Map each of the ``2 * stage_num`` pp_ranks to an RPC worker id.

        Up rank ``p`` runs on worker ``p``. Down rank ``p + stage_num`` runs on
        worker ``stage_num - p - 1``, so the down pipeline traverses the workers
        in reverse order.
        """
        stage_num = self.stage_num
        self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
            self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1

    def _create_pp_rank_to_module_partition_id(self) -> None:
        """Map up rank ``p`` and down rank ``p + stage_num`` to the same module partition ``p``."""
        stage_num = self.stage_num
        self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
        for pp_rank in range(stage_num):
            self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
            self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank

    def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
        """Allocate per-microbatch future slots for the output ranks of both pipelines.

        The returned dict has keys ``pp_rank`` (up) and ``pp_rank + stage_num``
        (down) for each output rank. Each value is a list of ``num_microbatches``
        ``None`` placeholders.
        """
        num_microbatches = self.num_microbatches
        stage_num = self.stage_num
        up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
        down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
        # merge up and down
        return {**up_ret_future, **down_ret_future}

    def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
        """Send ``microbatch`` via ``remote()`` to the input ranks of the pipeline that owns it.

        Even ids go to the up pipeline (offset 0); odd ids go to the down
        pipeline (offset ``stage_num``).
        """
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in input_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)

    def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
        """Send ``microlabels`` via ``remote()`` to the output ranks of the pipeline that owns the microbatch.

        Even ids go to the up pipeline (offset 0); odd ids go to the down
        pipeline (offset ``stage_num``).
        """
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            worker_rref.remote().set_labels(microbatch_id, microlabels)

    def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        """Register async futures for the forward output of ``microbatch_id`` on its pipeline's output ranks.

        Each future is stored in ``ret_future[pp_rank + offset][microbatch_id]``.
        """
        key = UniqueKey(microbatch_id, Phase.FORWARD)
        offset = (microbatch_id % 2) * self.stage_num
        for pp_rank in output_pp_ranks:
            worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
            ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)

    def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
        """Block until the last backward of both pipelines is available on every input rank.

        Does nothing when ``forward_only`` is true.
        """
        if forward_only:
            return

        stage_num = self.stage_num
        num_microbatches = self.num_microbatches
        # The up pipeline handles even ids and the down pipeline odd ids, so their
        # last microbatches are num_microbatches - 2 and num_microbatches - 1.
        # These keys are loop-invariant, so build them once.
        up_key = UniqueKey(num_microbatches - 2, Phase.BACKWARD)
        down_key = UniqueKey(num_microbatches - 1, Phase.BACKWARD)

        for pp_rank in input_pp_ranks:
            up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
            down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]

            up_worker_rref.rpc_sync().get_output_by_key(up_key)
            down_worker_rref.rpc_sync().get_output_by_key(down_key)

    def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
        """Logic of collection of forward in Chimera.

        Currently, only one input one output model is supported.

        For each output rank, wait on the future of every microbatch, read from
        the up or down pipeline by microbatch parity. A bare tensor result is
        wrapped in a list. The per-microbatch results are then transposed with
        ``zip``, so each appended entry is a tuple holding one output position
        across all microbatches.
        """
        stage_num = self.stage_num
        forward_result = []
        for pp_rank in output_pp_ranks:
            worker_forward_result = [None] * self.num_microbatches
            for microbatch_id in range(self.num_microbatches):
                offset = (microbatch_id % 2) * stage_num
                ret = ret_future[pp_rank + offset][microbatch_id].wait()
                ret = [ret] if isinstance(ret, torch.Tensor) else ret
                worker_forward_result[microbatch_id] = ret

            worker_forward_result = list(zip(*worker_forward_result))
            forward_result.extend(worker_forward_result)

        return forward_result
|