[pipeline/fix-bug] num_microbatches supports any integer | stable chimera | launch tool for rpc pp framework (#1684)

* [pipeline/tuning] improve dispatch performance in both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera

* [pipeline/chimera] test chimera | fix initialization bug

* [pipeline/pytree] add pytree to process args and kwargs | provide  to process args and kwargs after forward

* [pipeline/fix-bug] num_microbatches supports any integer | stable chimera | launch tool for rpc pp framework
Kirigaya Kazuto 2022-10-10 16:01:02 +08:00 committed by GitHub
parent e5ab6be72e
commit 0df5034a36
4 changed files with 98 additions and 24 deletions


@@ -50,6 +50,7 @@ class PipelineProcessGroup:
self.is_initialize = True
# lock
self.initialise_lock = threading.Lock()
self.chimera_lock = threading.Lock()
def _initialize_process_group(self):


@@ -3,9 +3,7 @@ from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
from functools import partial
from abc import ABC, abstractmethod
import sys
import os
import time
import math
import inspect
import torch
@@ -831,13 +829,16 @@ class PipelineEngineBase(ABC, nn.Module):
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
batch_lengths = get_batch_lengths(batch)
batch_length = batch_lengths[0]
if labels is not None and not forward_only:
assert hasattr(
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
num_microbatches = self.num_microbatches
microbatch_size = batch_lengths[0] // num_microbatches
assert batch_length >= num_microbatches, "num_microbatches must not exceed the batch size"
microbatch_size = math.ceil(batch_length / num_microbatches)
device = self.device
# In Chimera mode, the ranks of the down pipeline are excluded from 'input_pp_ranks' and 'output_pp_ranks'
@@ -852,7 +853,7 @@
# to avoid exceeding the wait limit
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
batch_start = microbatch_size * microbatch_id
batch_end = batch_start + microbatch_size
batch_end = min(batch_start + microbatch_size, batch_length)
# set input
microbatch = split_batch(batch, batch_start, batch_end, device)
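The sizing change above is easiest to see in isolation. Below is a minimal, self-contained sketch (illustrative names, not ColossalAI code) of how ceil-division plus clamping lets num_microbatches be any integer up to the batch size:

import math

def microbatch_bounds(batch_length, num_microbatches):
    # mirrors the assert and arithmetic in forward_backward above
    assert batch_length >= num_microbatches, "num_microbatches must not exceed the batch size"
    microbatch_size = math.ceil(batch_length / num_microbatches)
    for microbatch_id in range(num_microbatches):
        batch_start = microbatch_size * microbatch_id
        batch_end = min(batch_start + microbatch_size, batch_length)
        yield batch_start, batch_end

# e.g. batch_length=10, num_microbatches=4 -> (0, 3), (3, 6), (6, 9), (9, 10):
# only the last microbatch is smaller, so uneven batches no longer break.
# Note: with very lopsided splits (e.g. 5 samples over 4 microbatches) a
# trailing microbatch can come out empty, since ceil rounds the size up.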


@@ -1,4 +1,5 @@
from typing import List, Callable, Dict
import threading
import torch
import torch.distributed as dist
@@ -81,7 +82,8 @@ class OneFOneBWorker(WorkerBase):
# 2. the forward count reaches num_microbatches, which is the end of 1F1B mode
if not is_last_stage and \
target_key.phase == Phase.FORWARD:
if target_key.microbatch_id == actual_stage_num - 1:
if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2:
# Why require num_microbatches > 2? Because there is no steady phase when num_microbatches <= 2
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
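For context, a hedged sketch of the bookkeeping this branch sets up (function name is illustrative; the real scheduling lives in OneFOneBWorker): once a non-last stage has issued its warm-up forwards, it holds between actual_stage_num - pp_rank - 1 and actual_stage_num - pp_rank microbatches in flight, and no such steady phase exists for num_microbatches <= 2:

def steady_outstanding_range(pp_rank, actual_stage_num, num_microbatches):
    # no steady 1F1B phase exists with so few microbatches
    if num_microbatches <= 2:
        return None
    outstanding_min = actual_stage_num - pp_rank - 1
    outstanding_max = actual_stage_num - pp_rank
    return outstanding_min, outstanding_max

# e.g. with 4 stages: stage 0 keeps 3-4 microbatches outstanding,
# stage 2 keeps 1-2; the last stage never enters this branch at all.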
@@ -186,6 +188,19 @@ class ChimeraWorker(WorkerBase):
# init group for chimera in ppg
ppg.get_chimera_all_reduce_group(pp_rank)
# lock for step sync
self.step_sync_lock = threading.Lock()
self.step_sync_lock.acquire()
self.have_grad_lock = threading.Lock()
self.have_grad_lock.acquire()
def _get_lock_gradient(self):
self.have_grad_lock.acquire()
grads = self.get_parameter_gradients()
self.step_sync_lock.release()
return grads
def is_first_stage(self):
return (self.pp_rank % self.actual_stage_num) == 0
@@ -214,27 +229,22 @@
return local_device_pp_ranks
def _hook_before_step(self):
self.have_grad_lock.release()
pp_rank = self.pp_rank
orders = self._get_step_order()
step_index = orders.index(pp_rank)
stage_num = self.actual_stage_num
co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)
# if the current pp_rank is not the first to step,
# wait for its predecessor pp_rank to finish stepping
all_reduce_group = ppg.get_chimera_all_reduce_group(self.pp_rank)
grads = self.get_parameter_gradients()
# print(self.pp_rank, "begin all reduce", torch.cuda.max_memory_allocated(ppg.get_local_pp_rank()), torch.cuda.max_memory_reserved(ppg.get_local_pp_rank()))
if step_index == 1:
ppg.chimera_step_lock.acquire()
# print(f'rank_{self.pp_rank} before all reduce')
dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
# print(f'rank_{self.pp_rank} after all reduce')
if step_index == 0:
ppg.chimera_step_lock.release()
# fetch the co-worker's gradients (blocks until they are ready)
co_worker = self.pp_rank_to_worker_rref[co_pp_rank]
co_grads = co_worker.rpc_sync()._get_lock_gradient()
# wait until the co-worker has taken our gradients
self.step_sync_lock.acquire()
for i in range(len(grads)):
grads[i] += co_grads[i]
class ChimeraPipelineEngine(PipelineEngineBase):
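The two locks acquired in __init__ implement a rendezvous between the paired up/down workers. A simplified sketch of that handshake, with plain threads standing in for the RPC workers (PairedWorker is hypothetical, not part of this commit):

import threading

class PairedWorker:
    def __init__(self, grads):
        self.grads = grads
        # both locks start held, exactly as in ChimeraWorker above
        self.step_sync_lock = threading.Lock()
        self.step_sync_lock.acquire()
        self.have_grad_lock = threading.Lock()
        self.have_grad_lock.acquire()

    def _get_lock_gradient(self):
        # called by the co-worker: wait until our grads are ready,
        # then unblock us past our own step_sync barrier
        self.have_grad_lock.acquire()
        grads = self.grads
        self.step_sync_lock.release()
        return grads

    def hook_before_step(self, co_worker):
        self.have_grad_lock.release()               # our grads are ready
        co_grads = co_worker._get_lock_gradient()   # blocks until theirs are
        self.step_sync_lock.acquire()               # wait until ours were taken
        for i in range(len(co_grads)):
            self.grads[i] += co_grads[i]

# Both workers must call hook_before_step concurrently (one thread each),
# mirroring how the up and down pipelines reach _hook_before_step together.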
@@ -257,8 +267,8 @@ class ChimeraPipelineEngine(PipelineEngineBase):
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint, data_process_func)
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
output_pp_ranks: List[int], ret_future):
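# no consume constraint is needed for Chimera, so this override is intentionally a no-op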
pass
def _create_pp_rank_to_rpc_worker_id(self) -> None:


@@ -1,10 +1,18 @@
from typing import List, Any, Tuple, Dict, Callable, Type, Union
import os
import warnings
import argparse
import torch
import torch.multiprocessing as mp
from torch.futures import Future
import torch.distributed.rpc as rpc
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colorama import Back, Style
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
# config for debug and test
use_color_debug = False
@@ -87,3 +95,57 @@ def get_real_args_kwargs(args_or_kwargs):
args_or_kwargs = flatten_args
return args_or_kwargs
def run_worker(rank, args, master_func):
os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_PORT'] = args.master_port
device = args.device
world_size = args.world_size
dp_degree = args.dp_degree
tp_degree = args.tp_degree
num_worker_threads = args.num_worker_threads
host = args.master_addr
port = args.master_port
backend = 'nccl' if device == 'cuda' else 'gloo'
launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
ppg.set_global_info(rank=rank,
world_size=world_size,
dp_degree=dp_degree,
tp_degree=tp_degree,
num_worker_threads=num_worker_threads,
device=device)
ppg.args = args
# in RPC mode, only rank 0 needs to run user code
if rank == 0:
master_func(args)
# barrier: rpc.shutdown() blocks until all workers reach it
if _is_current_rpc_agent_set():
rpc.shutdown()
else:
warnings.warn("RPC has not been initialized")
def rpc_run(args, master_func):
world_size = args.world_size
mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=1)
parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--dp_degree', type=int, default=1)
parser.add_argument('--tp_degree', type=int, default=1)
parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--chunk', type=int, default=1)
parser.add_argument('--use_checkpoint', action='store_true')
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29020')
parser.add_argument('--num_worker_threads', type=int, default=128)
return parser.parse_args()
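For reference, a minimal entry point wiring these helpers together; master and its body are hypothetical placeholders, not part of this commit:

def master(args):
    # rank 0 builds the pipeline engine here and drives training,
    # e.g. calling engine.forward_backward(batch, labels) each step
    ...

if __name__ == '__main__':
    args = parse_args()
    rpc_run(args, master_func=master)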