mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera (#1595)
* [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimerapull/1609/head
parent
c9e8ce67b8
commit
edc9e419ad
|
@ -1,8 +1,9 @@
|
|||
import threading
|
||||
from enum import Enum
|
||||
from typing import List, Any, Tuple, Dict, Callable
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
import sys
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -17,12 +18,10 @@ from time import time
|
|||
from colorama import Back, Style
|
||||
|
||||
# config for debug and test
|
||||
use_color_debug = False
|
||||
use_progress = False
|
||||
use_color_debug = True
|
||||
|
||||
# TODO:
|
||||
# 1. replace world_size with other parameters
|
||||
# 2. adjust to args and kwargs
|
||||
# 1. adjust to args and kwargs (pytree)
|
||||
|
||||
|
||||
def color_debug(text, prefix=' ', color='blue'):
|
||||
|
@ -137,24 +136,24 @@ class BackwardCache:
|
|||
setattr(self, arg_name, locals()[arg_name])
|
||||
|
||||
|
||||
class Worker:
|
||||
class WorkerBase(ABC):
|
||||
|
||||
def __init__(self,
|
||||
module_partition: nn.Module,
|
||||
pp_rank: int,
|
||||
actual_stage_num: int,
|
||||
num_microbatches: int,
|
||||
use_1F1B: bool,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pp_rank = pp_rank
|
||||
self.actual_stage_num = actual_stage_num
|
||||
self.num_microbatches = num_microbatches
|
||||
self.checkpoint = checkpoint
|
||||
self.device = device
|
||||
self.use_1F1B = use_1F1B
|
||||
self._initialize_outstanding_range()
|
||||
|
||||
# variable and const for context managment
|
||||
|
@ -172,23 +171,14 @@ class Worker:
|
|||
|
||||
# module partitions
|
||||
self.module_partition = module_partition.to(device)
|
||||
if criterion:
|
||||
assert callable(criterion)
|
||||
self.criterion = criterion
|
||||
self.metric = metric
|
||||
|
||||
# container to maintain loop
|
||||
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
|
||||
self.microbatch_id_to_labels: Dict[int, Any] = dict()
|
||||
self.work_list: Dict[UniqueKey, WorkItem] = dict()
|
||||
self.output_list: Dict[UniqueKey, WorkItem] = dict()
|
||||
# context to maintain loop
|
||||
self._initialize_context_container()
|
||||
|
||||
# lock for the list
|
||||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.label_lock = threading.Condition(threading.Lock())
|
||||
|
||||
self.step_lock = threading.Lock()
|
||||
self.step_lock.acquire()
|
||||
self._initialize_lock()
|
||||
|
||||
# main loop
|
||||
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
|
||||
|
@ -199,13 +189,23 @@ class Worker:
|
|||
|
||||
def _initialize_outstanding_range(self):
|
||||
outstanding_range = None
|
||||
if self.use_1F1B:
|
||||
if self.pp_rank == self.actual_stage_num - 1:
|
||||
outstanding_range = (0, 1)
|
||||
else:
|
||||
outstanding_range = (self.actual_stage_num, self.actual_stage_num)
|
||||
if self.pp_rank == self.actual_stage_num - 1:
|
||||
outstanding_range = (0, 1)
|
||||
else:
|
||||
outstanding_range = (self.actual_stage_num, self.actual_stage_num)
|
||||
self.outstanding_range = outstanding_range
|
||||
|
||||
def _initialize_context_container(self):
|
||||
self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict()
|
||||
self.microbatch_id_to_labels: Dict[int, Any] = dict()
|
||||
self.work_list: Dict[UniqueKey, WorkItem] = dict()
|
||||
self.output_list: Dict[UniqueKey, WorkItem] = dict()
|
||||
|
||||
def _initialize_lock(self):
|
||||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.label_lock = threading.Condition(threading.Lock())
|
||||
|
||||
def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> None:
|
||||
assert self.pp_rank_to_worker_rref is None, f"in rank {self.pp_rank}, worker has sync global workers rrefs"
|
||||
assert pp_rank_to_worker_rref is not None, "stage_to_workers must be a dict instead of None"
|
||||
|
@ -241,12 +241,15 @@ class Worker:
|
|||
forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
color_debug(f'rank {self.pp_rank} receive data from dataloader', 'data dispatch', 'magenta')
|
||||
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
|
||||
'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# just for last pp_rank
|
||||
def set_labels(self, microbatch_id: int, microlabels: Any):
|
||||
self.microbatch_id_to_labels[microbatch_id] = microlabels
|
||||
with self.label_lock:
|
||||
self.microbatch_id_to_labels[microbatch_id] = microlabels
|
||||
self.label_lock.notify_all()
|
||||
|
||||
# just for last pp_rank
|
||||
def _begin_backward(self, microbatch_id: int):
|
||||
|
@ -354,51 +357,17 @@ class Worker:
|
|||
if next_rank <= self.actual_stage_num - 1:
|
||||
self.consumer_stage_ids.append(next_rank)
|
||||
|
||||
@abstractmethod
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
# execute backward first (if backward phase in work_list)
|
||||
pp_rank = self.pp_rank
|
||||
actual_stage_num = self.actual_stage_num
|
||||
num_microbatches = self.num_microbatches
|
||||
is_last_stage = pp_rank == actual_stage_num - 1
|
||||
"""
|
||||
this method control the order of the microbatch to consume
|
||||
"""
|
||||
|
||||
if self.outstanding_range:
|
||||
if self.outstanding <= self.outstanding_range[0]:
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
elif self.outstanding >= self.outstanding_range[1]:
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
else:
|
||||
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
|
||||
def is_first_stage(self):
|
||||
return self.pp_rank == 0
|
||||
|
||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||
|
||||
# change outstanding_range at:
|
||||
# 1. forward times reach actual_stage_num, this is the end of continuous forward
|
||||
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
||||
if not is_last_stage and \
|
||||
target_key.phase == Phase.FORWARD:
|
||||
if target_key.microbatch_id == actual_stage_num - 1:
|
||||
outstanding_min = actual_stage_num - pp_rank - 1
|
||||
outstanding_max = actual_stage_num - pp_rank
|
||||
self.outstanding_range = (outstanding_min, outstanding_max)
|
||||
elif target_key.microbatch_id == num_microbatches - 1:
|
||||
self.outstanding_range = (0, 0)
|
||||
|
||||
else:
|
||||
if self.forward_times < num_microbatches:
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
else:
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
|
||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||
|
||||
return target_key
|
||||
def is_last_stage(self):
|
||||
return self.pp_rank == self.actual_stage_num - 1
|
||||
|
||||
def _consume_work_item_by_phase(self, work_item: WorkItem):
|
||||
phase = work_item.phase
|
||||
|
@ -408,9 +377,8 @@ class Worker:
|
|||
forward_only = work_item.forward_only
|
||||
consume_result = None
|
||||
|
||||
# TODO : use process manager to acquire rank info later
|
||||
is_first_stage = (self.pp_rank == 0)
|
||||
is_last_stage = (self.pp_rank == self.actual_stage_num - 1)
|
||||
is_first_stage = self.is_first_stage()
|
||||
is_last_stage = self.is_last_stage()
|
||||
|
||||
# if self.pp_rank == 0:
|
||||
# print(
|
||||
|
@ -433,6 +401,21 @@ class Worker:
|
|||
if forward_only:
|
||||
with torch.no_grad():
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
|
||||
# TODO : integrate output list
|
||||
if is_last_stage and self.criterion:
|
||||
with self.label_lock:
|
||||
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
|
||||
labels = self.microbatch_id_to_labels.pop(microbatch_id)
|
||||
loss: torch.Tensor = self.criterion(consume_result, labels)
|
||||
if self.metric is not None:
|
||||
metric_result = self.metric(consume_result, labels)
|
||||
if isinstance(metric_result, torch.Tensor):
|
||||
metric_result = metric_result.item()
|
||||
else:
|
||||
metric_result = None
|
||||
consume_result = [loss.item(), metric_result]
|
||||
|
||||
stage_outputs = None
|
||||
stage_inputs = None
|
||||
use_checkpoint = None
|
||||
|
@ -444,10 +427,21 @@ class Worker:
|
|||
use_checkpoint = True
|
||||
else:
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
|
||||
|
||||
if is_last_stage and self.criterion:
|
||||
with self.label_lock:
|
||||
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
|
||||
labels = self.microbatch_id_to_labels.pop(microbatch_id)
|
||||
loss: torch.Tensor = self.criterion(consume_result, labels)
|
||||
consume_result = loss.item()
|
||||
if self.metric is not None:
|
||||
metric_result = self.metric(consume_result, labels)
|
||||
if isinstance(metric_result, torch.Tensor):
|
||||
metric_result = metric_result.item()
|
||||
else:
|
||||
metric_result = None
|
||||
|
||||
consume_result = [loss.item(), metric_result]
|
||||
else:
|
||||
loss = consume_result
|
||||
|
||||
|
@ -486,6 +480,7 @@ class Worker:
|
|||
|
||||
if use_checkpoint:
|
||||
stage_outputs = [self.module_partition(*stage_inputs)]
|
||||
|
||||
# overlap recompute and future.wait
|
||||
grad_tensors = get_real_args(args)
|
||||
|
||||
|
@ -513,11 +508,17 @@ class Worker:
|
|||
grad_sum += p.grad.sum()
|
||||
return grad_sum
|
||||
|
||||
def _is_first_step(self, work_item) -> bool:
|
||||
def _is_first_step(self, work_item: WorkItem) -> bool:
|
||||
return work_item.phase == Phase.FORWARD and work_item.microbatch_id == 0
|
||||
|
||||
def _is_last_step(self, work_item) -> bool:
|
||||
return work_item.phase == Phase.BACKWARD and work_item.microbatch_id == self.num_microbatches - 1
|
||||
def _is_last_step(self, work_item: WorkItem) -> bool:
|
||||
if work_item.forward_only:
|
||||
last_phase = Phase.FORWARD
|
||||
else:
|
||||
last_phase = Phase.BACKWARD
|
||||
is_last_phase = work_item.phase == last_phase
|
||||
is_last_microbatch = work_item.microbatch_id == self.num_microbatches - 1
|
||||
return is_last_phase and is_last_microbatch
|
||||
|
||||
# do the main loop to consume ready_list
|
||||
def _work_loop(self):
|
||||
|
@ -551,7 +552,7 @@ class Worker:
|
|||
|
||||
# if is last step in one batch reset context and do step
|
||||
if self._is_last_step(work_item):
|
||||
if hasattr(self, 'optimizer'):
|
||||
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
||||
self.step()
|
||||
self.forward_times = 0
|
||||
self.backward_times = 0
|
||||
|
@ -560,22 +561,22 @@ class Worker:
|
|||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
||||
self.step_lock = threading.Lock()
|
||||
self.step_lock.acquire()
|
||||
|
||||
def wait_for_step(self):
|
||||
self.step_lock.acquire()
|
||||
|
||||
def step(self):
|
||||
# print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
|
||||
self.optimizer.step()
|
||||
# print(f'rank_{self.pp_rank}', sum([p.sum() for p in self.module_partition.parameters()]))
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
self.step_lock.release()
|
||||
|
||||
|
||||
class PipelineEngineBase(ABC, nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
worker_type,
|
||||
module_partitions,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
|
@ -583,17 +584,19 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
use_1F1B=False,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.worker_type = worker_type
|
||||
self.module_partitions: List[nn.Module] = module_partitions
|
||||
self.chunk = chunk
|
||||
self.criterion = criterion
|
||||
self.metric = metric
|
||||
self.num_microbatches = num_microbatches
|
||||
self.device = device
|
||||
self.use_1F1B = use_1F1B
|
||||
self.stage_num = stage_num
|
||||
self.checkpoint = checkpoint
|
||||
self.use_interleave = chunk > 1
|
||||
|
||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
|
||||
|
||||
|
@ -601,26 +604,24 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
self._check_argument()
|
||||
self._create_pp_rank_to_rpc_worker_id()
|
||||
self._create_pp_rank_to_module_partition_id()
|
||||
self._init_worker()
|
||||
|
||||
def _check_argument(self):
|
||||
def _check_argument(self) -> None:
|
||||
self.virtual_stage_num = self.stage_num * self.chunk
|
||||
|
||||
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
|
||||
assert self.virtual_stage_num == len(
|
||||
self.module_partitions), "stage_num * chunk must be equal to length of model partition!"
|
||||
if self.use_interleave:
|
||||
assert self.num_microbatches % self.stage_num == 0, "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
|
||||
def _get_actual_stage_num(self):
|
||||
def _get_actual_stage_num(self) -> int:
|
||||
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
||||
|
||||
def _create_pp_rank_to_rpc_worker_id(self):
|
||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||
"""create a map from model partition to stage_id, which is useful when use_interleave is True.
|
||||
e.g. If a model is splited into 4 parts, which means len(self.module_partitions) == 3.
|
||||
stage_num is 2, chunk is 2, then pp_rank_to_rpc_worker_id = [0, 1, 0, 1], that means first and third part
|
||||
of partitions will be moved to device 0 and the others to device 1
|
||||
|
||||
"""
|
||||
stage_num = self.stage_num
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
|
@ -628,28 +629,39 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
for pp_rank in range(actual_stage_num):
|
||||
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank % stage_num
|
||||
|
||||
def _init_worker(self):
|
||||
def _create_pp_rank_to_module_partition_id(self) -> None:
|
||||
"""By default(both fill drain and 1F1B), length of model partitions equal to
|
||||
actual_stage_num, so allocate model partition to corresponding stage
|
||||
"""
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
self.pp_rank_to_module_partition_id = [0] * actual_stage_num
|
||||
for pp_rank in range(actual_stage_num):
|
||||
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
|
||||
|
||||
def _init_worker(self) -> None:
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
|
||||
use_1F1B = self.use_1F1B
|
||||
worker_type = self.worker_type
|
||||
checkpoint = self.checkpoint
|
||||
num_microbatches = self.num_microbatches
|
||||
device = self.device
|
||||
criterion = self.criterion
|
||||
metric = self.metric
|
||||
|
||||
for pp_rank in range(actual_stage_num):
|
||||
module_partition = self.module_partitions[pp_rank]
|
||||
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
|
||||
module_partition_id = self.pp_rank_to_module_partition_id[pp_rank]
|
||||
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
|
||||
if device[:4] == 'cuda':
|
||||
device = f'cuda:{rpc_worker_id}'
|
||||
module_partition = self.module_partitions[module_partition_id]
|
||||
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
|
||||
Worker,
|
||||
worker_type,
|
||||
args=(module_partition, pp_rank, actual_stage_num,
|
||||
num_microbatches, use_1F1B, device, criterion,
|
||||
num_microbatches, device, criterion, metric,
|
||||
checkpoint))
|
||||
|
||||
# let each worker know global worker rref (include itself)
|
||||
for pp_rank in range(actual_stage_num):
|
||||
for pp_rank in self.pp_rank_to_worker_rref:
|
||||
self.pp_rank_to_worker_rref[pp_rank].rpc_sync().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
|
||||
|
||||
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
|
||||
|
@ -670,65 +682,110 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
grads[stage_id].append(grad)
|
||||
return grads
|
||||
|
||||
def get_input_pp_ranks(self) -> List[int]:
|
||||
return [0]
|
||||
|
||||
def get_output_pp_ranks(self) -> List[int]:
|
||||
return [self._get_actual_stage_num() - 1]
|
||||
|
||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
|
||||
output_pp_ranks: List[int], ret_future):
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
use_1F1B = self.use_1F1B
|
||||
if microbatch_id >= actual_stage_num:
|
||||
if forward_only or not use_1F1B:
|
||||
for pp_rank in output_pp_ranks:
|
||||
ret_future[pp_rank][microbatch_id - actual_stage_num].wait()
|
||||
else:
|
||||
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_sync().get_output_by_key(key)
|
||||
|
||||
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
|
||||
num_microbatches = self.num_microbatches
|
||||
return {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
|
||||
|
||||
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
# TODO : add relationship between input_pp_ranks and parts of microbatch
|
||||
worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
|
||||
|
||||
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
|
||||
for pp_rank in output_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
# TODO : add relationship between output_pp_ranks and parts of microlabels
|
||||
worker_rref.remote().set_labels(microbatch_id, microlabels)
|
||||
|
||||
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
for pp_rank in output_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
ret_future[pp_rank][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
|
||||
|
||||
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
|
||||
if not forward_only:
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||
worker_rref.rpc_sync().get_output_by_key(key)
|
||||
|
||||
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
||||
forward_result = []
|
||||
for pp_rank in output_pp_ranks:
|
||||
worker_forward_result = [None] * self.num_microbatches
|
||||
for microbatch_id in range(self.num_microbatches):
|
||||
ret = ret_future[pp_rank][microbatch_id].wait()
|
||||
worker_forward_result[microbatch_id] = ret
|
||||
worker_forward_result = list(zip(*worker_forward_result))
|
||||
forward_result.extend(worker_forward_result)
|
||||
|
||||
return forward_result
|
||||
|
||||
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||
if labels is not None:
|
||||
assert len(batch) == len(labels)
|
||||
if not forward_only:
|
||||
assert hasattr(self, 'optimizer_class')
|
||||
|
||||
num_microbatches = self.num_microbatches
|
||||
microbatch_size = len(batch) // num_microbatches
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
|
||||
first_worker_rref = self.pp_rank_to_worker_rref[0]
|
||||
last_worker_rref = self.pp_rank_to_worker_rref[actual_stage_num - 1]
|
||||
# If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
|
||||
input_pp_ranks = self.get_input_pp_ranks()
|
||||
output_pp_ranks = self.get_output_pp_ranks()
|
||||
|
||||
microbatch_iter = range(num_microbatches)
|
||||
if use_progress:
|
||||
microbatch_iter = tqdm(microbatch_iter)
|
||||
# a cache to collect data and control flow
|
||||
ret_future = self._create_ret_future(output_pp_ranks)
|
||||
|
||||
ret_future: List[Future] = [None] * num_microbatches
|
||||
for microbatch_id in microbatch_iter:
|
||||
# control data input speed
|
||||
for microbatch_id in range(num_microbatches):
|
||||
# control data input speed
|
||||
# to prevent exceed of wait limitations
|
||||
if microbatch_id >= actual_stage_num:
|
||||
if forward_only or not self.use_1F1B:
|
||||
ret_future[microbatch_id - actual_stage_num].wait()
|
||||
else:
|
||||
key = UniqueKey(microbatch_id - actual_stage_num, Phase.BACKWARD)
|
||||
first_worker_rref.rpc_sync().get_output_by_key(key)
|
||||
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
|
||||
|
||||
# set input
|
||||
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||
microbatch = microbatch.cuda()
|
||||
first_worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
|
||||
self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)
|
||||
|
||||
# set labels
|
||||
if not forward_only and labels is not None:
|
||||
if labels is not None:
|
||||
microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||
microlabels = microlabels.cuda()
|
||||
last_worker_rref.remote().set_labels(microbatch_id, microlabels)
|
||||
self._set_labels(output_pp_ranks, microbatch_id, microlabels)
|
||||
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
ret_future[microbatch_id] = last_worker_rref.rpc_async().get_output_by_key(key)
|
||||
# get data asynchronously
|
||||
self._subscribe_forward(microbatch_id, output_pp_ranks, ret_future)
|
||||
|
||||
# wait for last backward in rank0
|
||||
if not forward_only:
|
||||
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
|
||||
first_worker_rref.rpc_sync().get_output_by_key(key)
|
||||
# wait for first rank to ensure all backwards are done
|
||||
self._ensure_backward(forward_only, input_pp_ranks)
|
||||
|
||||
# collect forward result
|
||||
# TODO : all the node to output
|
||||
forward_result = None
|
||||
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
|
||||
|
||||
for microbatch_id in range(self.num_microbatches):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
ret = ret_future[microbatch_id].wait()
|
||||
if forward_result is None:
|
||||
forward_result = [[]] * len(ret)
|
||||
for i in range(len(forward_result)):
|
||||
forward_result[i].append(ret[i])
|
||||
|
||||
if hasattr(self, 'optimizer_class'):
|
||||
if not forward_only and labels is not None:
|
||||
# wait for all step
|
||||
# TODO : more elegant ?
|
||||
for pp_rank in self.pp_rank_to_worker_rref:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_sync().wait_for_step()
|
||||
|
@ -751,31 +808,3 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
|
||||
for fut in self.step_futs:
|
||||
fut.wait()
|
||||
|
||||
|
||||
class FillDrainPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions: List[nn.Module],
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
use_1F1B = False
|
||||
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
|
||||
|
||||
|
||||
class OneFOneBPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions: List[nn.Module],
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
use_1F1B = True
|
||||
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
|
|
@ -0,0 +1,277 @@
|
|||
from typing import List, Callable, Dict
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.futures import Future
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
|
||||
from colossalai.pipeline.rpc._pipeline_base import PipelineEngineBase, WorkerBase, UniqueKey, Phase
|
||||
|
||||
# Implementation of different Pipeline schedule
|
||||
# <strategy>Worker defines the worker for each stage
|
||||
# <strategy>PipelineEngine is the class for use
|
||||
|
||||
|
||||
class FillDrainWorker(WorkerBase):
|
||||
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
# execute backward first (if backward phase in work_list)
|
||||
num_microbatches = self.num_microbatches
|
||||
|
||||
if self.forward_times < num_microbatches:
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
else:
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
|
||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||
|
||||
return target_key
|
||||
|
||||
|
||||
class FillDrainPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions: List[nn.Module],
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
|
||||
if chunk > 1:
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
use_1F1B = False
|
||||
|
||||
super().__init__(FillDrainWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
|
||||
criterion, metric, checkpoint)
|
||||
|
||||
|
||||
class OneFOneBWorker(WorkerBase):
|
||||
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
# execute backward first (if backward phase in work_list)
|
||||
pp_rank = self.pp_rank
|
||||
actual_stage_num = self.actual_stage_num
|
||||
num_microbatches = self.num_microbatches
|
||||
is_last_stage = pp_rank == actual_stage_num - 1
|
||||
|
||||
if self.outstanding <= self.outstanding_range[0]:
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
elif self.outstanding >= self.outstanding_range[1]:
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
else:
|
||||
raise ValueError("outstanding_range[1] - outstanding_range[0] must be in [0, 1]")
|
||||
|
||||
target_key = UniqueKey(target_microbatch_id, target_phase)
|
||||
|
||||
# change outstanding_range at:
|
||||
# 1. forward times reach actual_stage_num, this is the end of continuous forward
|
||||
# 2. forward times reach num_microbatches, this is the end of 1F1B mode
|
||||
if not is_last_stage and \
|
||||
target_key.phase == Phase.FORWARD:
|
||||
if target_key.microbatch_id == actual_stage_num - 1:
|
||||
outstanding_min = actual_stage_num - pp_rank - 1
|
||||
outstanding_max = actual_stage_num - pp_rank
|
||||
self.outstanding_range = (outstanding_min, outstanding_max)
|
||||
elif target_key.microbatch_id == num_microbatches - 1:
|
||||
self.outstanding_range = (0, 0)
|
||||
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||
|
||||
return target_key
|
||||
|
||||
|
||||
class OneFOneBPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions: List[nn.Module],
|
||||
stage_num: int,
|
||||
num_microbatches: int,
|
||||
device: str,
|
||||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
|
||||
if chunk > 1:
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!"
|
||||
use_1F1B = True
|
||||
|
||||
super().__init__(OneFOneBWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
|
||||
criterion, metric, checkpoint)
|
||||
|
||||
|
||||
class ChimeraWorker(WorkerBase):
|
||||
|
||||
def _get_producer_consumer(self) -> None:
|
||||
rank = self.pp_rank
|
||||
min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num
|
||||
max_pp_rank = min_pp_rank + self.actual_stage_num - 1
|
||||
|
||||
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
|
||||
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
|
||||
|
||||
# should be aranged in order, the order of the input of current forward
|
||||
self.producer_stage_ids = []
|
||||
self.consumer_stage_ids = []
|
||||
|
||||
# Just for demo
|
||||
prev_rank = rank - 1
|
||||
next_rank = rank + 1
|
||||
if prev_rank >= min_pp_rank:
|
||||
self.producer_stage_ids.append(prev_rank)
|
||||
if next_rank <= max_pp_rank:
|
||||
self.consumer_stage_ids.append(next_rank)
|
||||
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
pp_rank = self.pp_rank
|
||||
stage_num = self.actual_stage_num
|
||||
real_microbatch_num = self.num_microbatches // 2
|
||||
|
||||
if self.forward_times < real_microbatch_num:
|
||||
if (pp_rank + 1) % stage_num == 0: # last rank
|
||||
forward_blocks = self.forward_times // (self.num_microbatches // stage_num)
|
||||
if forward_blocks > self.backward_times:
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
else:
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
else: # others
|
||||
target_phase = Phase.FORWARD
|
||||
target_microbatch_id = self.forward_times
|
||||
else:
|
||||
target_phase = Phase.BACKWARD
|
||||
target_microbatch_id = self.backward_times
|
||||
|
||||
# In up pipeline, microbatch_id to consume is 0, 2, 4 (2n)
|
||||
# In down pipeline, microbatch_id to consume is 1, 3, 5 (2n + 1)
|
||||
real_target_microbatch_id = target_microbatch_id * 2
|
||||
if pp_rank >= stage_num:
|
||||
real_target_microbatch_id += 1
|
||||
target_key = UniqueKey(real_target_microbatch_id, target_phase)
|
||||
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
|
||||
|
||||
return target_key
|
||||
|
||||
def is_first_stage(self):
|
||||
return (self.pp_rank % self.actual_stage_num) == 0
|
||||
|
||||
def is_last_stage(self):
|
||||
return (self.pp_rank % self.actual_stage_num) == self.actual_stage_num - 1
|
||||
|
||||
|
||||
class ChimeraPipelineEngine(PipelineEngineBase):
|
||||
|
||||
def __init__(self,
|
||||
module_partitions,
|
||||
stage_num,
|
||||
num_microbatches,
|
||||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"In Chimera, num_microbatches must be the multiply of stage_num!"
|
||||
use_1F1B = False
|
||||
chunk = 1
|
||||
super().__init__(ChimeraWorker, module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk,
|
||||
criterion, metric, checkpoint)
|
||||
|
||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
|
||||
input_worker_rrefs: List[PyRRef], output_worker_rrefs: List[PyRRef]):
|
||||
pass
|
||||
|
||||
def _create_pp_rank_to_rpc_worker_id(self) -> None:
|
||||
stage_num = self.stage_num
|
||||
self.pp_rank_to_rpc_worker_id = [0] * (stage_num * 2)
|
||||
for pp_rank in range(stage_num):
|
||||
self.pp_rank_to_rpc_worker_id[pp_rank] = pp_rank
|
||||
self.pp_rank_to_rpc_worker_id[pp_rank + stage_num] = stage_num - pp_rank - 1
|
||||
|
||||
def _create_pp_rank_to_module_partition_id(self) -> None:
|
||||
stage_num = self.stage_num
|
||||
self.pp_rank_to_module_partition_id = [0] * (stage_num * 2)
|
||||
for pp_rank in range(stage_num):
|
||||
self.pp_rank_to_module_partition_id[pp_rank] = pp_rank
|
||||
self.pp_rank_to_module_partition_id[pp_rank + stage_num] = pp_rank
|
||||
|
||||
def _create_ret_future(self, output_pp_ranks: List[int]) -> Dict[int, List[Future]]:
|
||||
num_microbatches = self.num_microbatches
|
||||
stage_num = self.stage_num
|
||||
up_ret_future = {pp_rank: [None] * num_microbatches for pp_rank in output_pp_ranks}
|
||||
down_ret_future = {pp_rank + stage_num: [None] * num_microbatches for pp_rank in output_pp_ranks}
|
||||
# merge up and down
|
||||
return {**up_ret_future, **down_ret_future}
|
||||
|
||||
def _set_input(self, input_pp_ranks: List[int], microbatch_id: int, microbatch, forward_only: bool):
|
||||
# offset is 0 for all the ranks in up pipeline
|
||||
# offset is stage_num for all the ranks in down pipeline
|
||||
offset = (microbatch_id % 2) * self.stage_num
|
||||
for pp_rank in input_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
|
||||
worker_rref.remote().set_input(microbatch_id, microbatch, forward_only)
|
||||
|
||||
def _set_labels(self, output_pp_ranks: List[int], microbatch_id: int, microlabels):
|
||||
# offset is 0 for all the ranks in up pipeline
|
||||
# offset is stage_num for all the ranks in down pipeline
|
||||
offset = (microbatch_id % 2) * self.stage_num
|
||||
for pp_rank in output_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
|
||||
worker_rref.remote().set_labels(microbatch_id, microlabels)
|
||||
|
||||
def _subscribe_forward(self, microbatch_id: int, output_pp_ranks: List[int], ret_future: Dict[int, List[Future]]):
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
offset = (microbatch_id % 2) * self.stage_num
|
||||
for pp_rank in output_pp_ranks:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank + offset]
|
||||
ret_future[pp_rank + offset][microbatch_id] = worker_rref.rpc_async().get_output_by_key(key)
|
||||
|
||||
def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]):
|
||||
stage_num = self.stage_num
|
||||
num_microbatches = self.num_microbatches
|
||||
if not forward_only:
|
||||
for pp_rank in input_pp_ranks:
|
||||
up_last_microbatch_id = num_microbatches - 2
|
||||
down_last_microbatch_id = num_microbatches - 1
|
||||
|
||||
up_worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
down_worker_rref = self.pp_rank_to_worker_rref[pp_rank + stage_num]
|
||||
|
||||
up_key = UniqueKey(up_last_microbatch_id, Phase.BACKWARD)
|
||||
down_key = UniqueKey(down_last_microbatch_id, Phase.BACKWARD)
|
||||
|
||||
up_worker_rref.rpc_sync().get_output_by_key(up_key)
|
||||
down_worker_rref.rpc_sync().get_output_by_key(down_key)
|
||||
|
||||
def _collect_forward_result(self, output_pp_ranks: List[int], ret_future: Dict[PyRRef, List[Future]]):
|
||||
"""Logic of collection of forward in Chimera.
|
||||
Currently, only one input one output model is supported
|
||||
"""
|
||||
stage_num = self.stage_num
|
||||
forward_result = []
|
||||
for pp_rank in output_pp_ranks:
|
||||
worker_forward_result = [None] * self.num_microbatches
|
||||
for microbatch_id in range(self.num_microbatches):
|
||||
offset = (microbatch_id % 2) * stage_num
|
||||
ret = ret_future[pp_rank + offset][microbatch_id].wait()
|
||||
worker_forward_result[microbatch_id] = ret
|
||||
|
||||
worker_forward_result = list(zip(*worker_forward_result))
|
||||
forward_result.extend(worker_forward_result)
|
||||
|
||||
return forward_result
|
Binary file not shown.
|
@ -0,0 +1,43 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
|
||||
def run_master(args):
|
||||
torch.manual_seed(100)
|
||||
|
||||
epoch = args.epoch
|
||||
device = args.device
|
||||
stage_num = 4
|
||||
chunk = 1
|
||||
num_microbatches = 4
|
||||
actual_stage_num = 4
|
||||
use_checkpoint = False
|
||||
|
||||
sample_num = 1024
|
||||
feat_num = 10
|
||||
h = 10
|
||||
batch_size = 1024
|
||||
|
||||
assert sample_num % batch_size == 0
|
||||
|
||||
module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
|
||||
engine = ChimeraPipelineEngine(module_partitions=module_partitions,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
checkpoint=use_checkpoint)
|
||||
|
||||
input_sample = torch.randn((sample_num, feat_num), device=device)
|
||||
|
||||
for _ in range(epoch):
|
||||
_ = engine.forward_backward(input_sample, forward_only=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
args.world_size = 4
|
||||
args.num_microbatches = 4
|
||||
rpc_run(args, run_master)
|
|
@ -3,7 +3,7 @@ from torch import nn
|
|||
from torch import autograd
|
||||
from torch.optim import SGD, Adam, RMSprop, Optimizer
|
||||
|
||||
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.testing import assert_close
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
from typing import Callable, List, Optional, Type, Union
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from titans.dataloader.cifar10 import build_cifar
|
||||
from torchvision.models import resnet50
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
|
||||
from tqdm import tqdm
|
||||
|
||||
from rpc_test_utils import rpc_run, parse_args
|
||||
import colossalai
|
||||
import colossalai.nn as col_nn
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
|
||||
|
||||
def flatten(x):
|
||||
return torch.flatten(x, 1)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return torch.flatten(x, start_dim=1)
|
||||
|
||||
|
||||
def run_master(args):
|
||||
batch_size = args.batch_size
|
||||
chunk = args.chunk
|
||||
device = args.device
|
||||
world_size = args.world_size
|
||||
stage_num = world_size
|
||||
num_microbatches = args.num_microbatches
|
||||
|
||||
assert chunk == 1
|
||||
|
||||
pipelinable = PipelinableContext()
|
||||
|
||||
# build model partitions
|
||||
with pipelinable:
|
||||
# input : [B, 3, 32, 32]
|
||||
model = resnet50()
|
||||
|
||||
exec_seq = [
|
||||
'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', (flatten, "behind"), 'fc'
|
||||
]
|
||||
pipelinable.to_layer_list(exec_seq)
|
||||
module_partitions: List[PipelinableModel] = [
|
||||
pipelinable.partition(chunk, stage_num, pp_rank) for pp_rank in range(world_size)
|
||||
]
|
||||
|
||||
# build dataloader
|
||||
root = os.environ.get('DATA', './data')
|
||||
train_dataloader, test_dataloader = build_cifar(batch_size, root, padding=4, crop=32, resize=32)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
partition_1 = module_partitions[0]
|
||||
partition_2 = []
|
||||
for model in module_partitions[1]._module_list:
|
||||
partition_2.append(model)
|
||||
partition_2.insert(len(partition_2) - 1, Flatten())
|
||||
partition_2 = nn.Sequential(*partition_2)
|
||||
module_partitions = [partition_1, partition_2]
|
||||
|
||||
pp_engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=chunk,
|
||||
criterion=criterion,
|
||||
checkpoint=False)
|
||||
|
||||
pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
|
||||
s = time.time()
|
||||
|
||||
for bx, by in tqdm(train_dataloader):
|
||||
pp_engine.forward_backward(bx, labels=by, forward_only=False)
|
||||
|
||||
cost_time = time.time() - s
|
||||
|
||||
print("total cost time :", cost_time)
|
||||
print("cost time per batch:", cost_time / len(train_dataloader))
|
||||
|
||||
|
||||
@pytest.mark.skip("Test for performance, no need for CI")
|
||||
def main():
|
||||
args = parse_args()
|
||||
# this is due to limitation of partition function
|
||||
args.world_size = 2
|
||||
args.chunk = 1
|
||||
rpc_run(args, run_master)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch import autograd
|
||||
|
||||
from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.testing import assert_close
|
||||
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
|
||||
|
||||
|
@ -36,7 +36,7 @@ def run_master(args):
|
|||
chunk=chunk,
|
||||
checkpoint=use_checkpoint)
|
||||
|
||||
forward_result = engine.forward_backward(input_sample)
|
||||
forward_result = engine.forward_backward(input_sample)[0]
|
||||
|
||||
cuda_rpc_result = []
|
||||
single_result = []
|
||||
|
|
Loading…
Reference in New Issue