[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)

* add splitter

* polish code

* remove comment

* fix async nan by moving to cpu first

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang 2022-12-23 11:38:43 +08:00 committed by GitHub
parent 937f404253
commit 59e343328d
4 changed files with 84 additions and 58 deletions


@@ -9,6 +9,30 @@ def pipe_split():
     pass
 
 
+def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+    """
+    In avgnode_split_pass, simply split graph by node number.
+    """
+    mod_graph = gm.graph
+    avg_num_node = len(mod_graph.nodes) // pp_size
+    accumulate_num_node = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        accumulate_num_node += 1
+        if accumulate_num_node >= avg_num_node:
+            accumulate_num_node = 0
+            pp_size -= 1
+            if node.next.op == 'output':
+                with mod_graph.inserting_before(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+            else:
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
 def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     """
     In balanced_split_pass, we split module by the size of parameters(weights+bias).
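For reference, avgnode_split_pass only needs a traced GraphModule and the number of pipeline stages. A minimal usage sketch, assuming the pass is importable from colossalai.fx.passes.adding_split_node_pass (the toy model is made up):

import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass    # assumed import path

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])

    def forward(self, x):
        return self.layers(x)

gm = symbolic_trace(ToyModel())            # torch.fx.GraphModule
gm = avgnode_split_pass(gm, pp_size=4)     # inserts a pipe_split() marker roughly every len(nodes) // 4 nodes
print(gm.graph)                            # the pipe_split call_function nodes now delimit the 4 stages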


@@ -16,6 +16,7 @@ from colossalai.pipeline.middleware import Partition, PartitionInputVal, Partiti
 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (
     get_batch_lengths,
+    pyobj_map,
     pytree_filter,
     pytree_map,
     split_batch,
@@ -199,36 +200,28 @@ class WorkerBase(ABC):
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
-            output = output_work_item.output
-            if not ref_use and output_work_item.phase != Phase.INPUT:
-                self.output_list.pop(key)
+            self.output_list.pop(key)
 
-        if not ref_use:
+        if not ref_use and output_work_item.phase != Phase.INPUT:
             output_work_item.refcount += 1
             refcount = output_work_item.refcount
-            if output_work_item.phase == Phase.FORWARD:
-                # lifecycle management for DAG scheduler
+            output = output_work_item.output
+            # lifecycle management for DAG scheduler
+            if output_work_item.phase == Phase.FORWARD:
                 lifecycle = len(self.get_consumer_stage_ids())
                 if self.is_model_output():    # an extra reference for scheduler collecting results
                     lifecycle += 1
-                with self.output_list_condition_lock:
-                    # all consumers have been satisfied, the work_item can be released
-                    # or put it into work list again.
-                    if refcount < lifecycle:
-                        self.output_list[key] = output_work_item
-                        self.output_list_condition_lock.notify_all()
             elif output_work_item.phase == Phase.BACKWARD:
                 lifecycle = len(self.get_producer_stage_ids())
-                if self._is_last_step(output_work_item):
-                    lifecycle += 1    # an extra reference for scheduler collecting results
-                with self.output_list_condition_lock:
-                    # all producers have been satisfied, the work_item can be released
-                    # or put it into work list again.
-                    if refcount < lifecycle:
-                        self.output_list[key] = output_work_item
-                        self.output_list_condition_lock.notify_all()
+                if self._is_last_step(output_work_item):    # an extra reference for ensure_backward
+                    lifecycle += 1
             else:
-                with self.output_list_condition_lock:
+                lifecycle = 0
+                refcount = 0
+
+            with self.output_list_condition_lock:
+                if refcount < lifecycle:
                     self.output_list[key] = output_work_item
                     self.output_list_condition_lock.notify_all()
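The restructuring above centralizes the release decision: every fetch pops the work item, bumps its refcount, and re-inserts it only while refcount is still below the item's lifecycle (the number of consumers, plus one extra reference where noted). A simplified standalone sketch of that idea, with hypothetical names rather than the actual WorkerBase code:

import threading

class OutputSlot:
    def __init__(self, output, lifecycle):
        self.output = output          # cached result of a work item
        self.lifecycle = lifecycle    # how many consumers must fetch it
        self.refcount = 0             # how many consumers have fetched it so far

class OutputList:
    def __init__(self):
        self._slots = {}
        self._cond = threading.Condition()

    def put(self, key, slot):
        with self._cond:
            self._slots[key] = slot
            self._cond.notify_all()

    def get(self, key):
        with self._cond:
            self._cond.wait_for(lambda: key in self._slots)
            slot = self._slots.pop(key)    # always pop first, as in the patch
            slot.refcount += 1
            if slot.refcount < slot.lifecycle:
                # more consumers still need this result: put it back and wake waiters
                self._slots[key] = slot
                self._cond.notify_all()
            return slot.output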
@@ -689,10 +682,12 @@ class WorkerBase(ABC):
         else:
             args_kwargs = self._get_real_args_kwargs_fwd(args)
 
-        if not forward_only:
-            pytree_map(args_kwargs,
-                       lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
-                       process_types=torch.Tensor)
+        # if not forward_only:
+        #     pytree_map(args_kwargs,
+        #                lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
+        #                process_types=torch.Tensor)
+        args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
+                                process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
 
         args, kwargs = data_process_func(args_kwargs)
@@ -762,6 +757,9 @@ class WorkerBase(ABC):
             if is_last_stage:    # if it is the last stage, trigger backward automatic
                 self._begin_backward(microbatch_id)
 
+            consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
+                                       process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
+
         elif phase == Phase.BACKWARD:
             # remind its producer to get data before backward
             if not is_first_stage:
@@ -807,6 +805,8 @@ class WorkerBase(ABC):
             stage_outputs = filtered_outputs
             grad_tensors = filtered_grads
 
+            grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
+                                     process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
+
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
             # collect grad of input tensor
@@ -818,6 +818,9 @@ class WorkerBase(ABC):
                     consume_result.append(arg.grad)
                 else:
                     consume_result.append(None)
 
+            consume_result = pyobj_map(
+                consume_result, fn=lambda x: x.to('cpu'),
+                process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
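The pyobj_map calls added in this file all serve one constraint: torch RPC cannot carry CUDA tensors in the arguments or return values here, so outputs and gradients are moved to CPU before crossing the RPC boundary and back to the worker's device before local compute. A condensed sketch of that pattern (the helper names are made up; pyobj_map is the utility added by this commit):

import torch
from colossalai.pipeline.rpc.utils import pyobj_map

def to_rpc_safe(obj):
    # move every tensor inside a nested structure to CPU before it is returned over RPC
    return pyobj_map(obj, fn=lambda x: x.to('cpu'), process_types=torch.Tensor)

def to_local_device(obj, device):
    # move received tensors (e.g. grad_tensors) back to this worker's device
    # before calling autograd.backward
    return pyobj_map(obj, fn=lambda x: x.to(device), process_types=torch.Tensor)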
@@ -882,9 +885,6 @@ class WorkerBase(ABC):
             # if is last step in one batch reset context and do step
             if self._is_last_step(work_item):
-                self._hook_before_step()
-                if hasattr(self, 'optimizer') and not work_item.forward_only:
-                    self.step()
                 self._wait_for_reset()
 
             # reset context and resume loop
@@ -904,23 +904,12 @@ class WorkerBase(ABC):
             self.reset_condition.notify_all()
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
-        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
-        # After tracer fixed, remove this part.
-        if len(list(self.module_partition.parameters())) > 0:
-            self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
-        self.step_lock = threading.Lock()
-        self.step_lock.acquire()
-
-    def wait_for_step(self):
-        self.step_lock.acquire()
+        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
 
     def step(self):
-        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
-        # After tracer fixed, remove this part.
-        if len(list(self.module_partition.parameters())) > 0:
-            self.optimizer.step()
-            self.optimizer.zero_grad()
-        self.step_lock.release()
+        self._hook_before_step()
+        self.optimizer.step()
+        self.optimizer.zero_grad()
class PipelineEngineBase(ABC, nn.Module):
@@ -1176,10 +1165,7 @@ class PipelineEngineBase(ABC, nn.Module):
         forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
 
         if not forward_only and hasattr(self, 'optimizer_class'):
-            # wait for all step
-            for pp_rank in self.pp_rank_to_worker_rref:
-                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().wait_for_step()
+            self.step()
 
         self._reset_worker()    # reset worker attributes for next batch
         return forward_result
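With wait_for_step and the step_lock gone, the engine now drives the optimizer step itself once a batch's results have been collected. A hypothetical sketch of what such a synchronous, engine-side step can look like (the real body of PipelineEngineBase.step is not shown in this diff):

def step_all_workers(engine):
    # one blocking RPC per stage; each worker runs _hook_before_step(),
    # optimizer.step() and optimizer.zero_grad() before the next batch starts
    for pp_rank, worker_rref in engine.pp_rank_to_worker_rref.items():
        worker_rref.rpc_sync().step()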


@@ -3,11 +3,12 @@ from typing import Callable, Dict, List
 import torch
 import torch.distributed as dist
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
 
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
+
 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage
 # <strategy>PipelineEngine is the class for use
@@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase):
             outstanding_min = actual_stage_num - pp_rank - 1
             outstanding_max = actual_stage_num - pp_rank
             self.outstanding_range = (outstanding_min, outstanding_max)
-        elif target_key.microbatch_id == num_microbatches - 1:
+        if target_key.microbatch_id == num_microbatches - 1:
             self.outstanding_range = (0, 0)
 
         return target_key
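This elif-to-if change is the deadlock fix named in the title: when num_microbatch equals num_stage, the last microbatch satisfies both the branch handled just above (which, per the hunk context, presumably fires at the end of warm-up) and the last-microbatch check, so with elif the outstanding_range was never reset to (0, 0) and the 1F1B worker waited forever. A small illustration with assumed values:

actual_stage_num = 4
num_microbatches = 4      # num_microbatch == num_stage, the failing case
microbatch_id = 3         # last microbatch of the batch

end_of_warmup = microbatch_id == actual_stage_num - 1        # True
last_microbatch = microbatch_id == num_microbatches - 1      # True

# with `elif`, only the warm-up branch ran and outstanding_range stayed open;
# with `if`, the reset to (0, 0) also runs, letting the worker drain its
# remaining work items instead of blocking.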


@@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union
 import torch
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
-from colossalai.initialize import launch
-from colossalai.pipeline.pipeline_process_group import ppg
 from torch._C._distributed_rpc import _is_current_rpc_agent_set
 from torch.futures import Future
 
+from colossalai.initialize import launch
+from colossalai.pipeline.pipeline_process_group import ppg
+
+
+def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
+    if isinstance(obj, process_types):
+        return fn(obj)
+    elif type(obj) is dict:
+        return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
+    elif type(obj) is tuple:
+        return tuple(pyobj_map(o, fn, process_types) for o in obj)
+    elif type(obj) is list:
+        return list(pyobj_map(o, fn, process_types) for o in obj)
+    else:
+        return obj
+
+
 def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
     """process object recursively, like pytree
@@ -57,6 +71,7 @@ def split_batch(batch: Any, start, stop, device: str):
 def type_detail(obj):
     return pytree_map(obj, lambda x: type(x), map_all=True)
 
+
 def pytree_filter(fn, obj, process_types):
     if obj is None:
         return None