[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)

* add splitter

* polish code

* remove comment

* fix async NaN by moving tensors to CPU first

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
Ziyue Jiang 2022-12-23 11:38:43 +08:00 committed by GitHub
parent 937f404253
commit 59e343328d
4 changed files with 84 additions and 58 deletions


@@ -9,6 +9,30 @@ def pipe_split():
pass
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgnode_split_pass, simply split the graph evenly by node count.
"""
mod_graph = gm.graph
avg_num_node = len(mod_graph.nodes) // pp_size
accumulate_num_node = 0
for node in mod_graph.nodes:
if pp_size <= 1:
break
accumulate_num_node += 1
if accumulate_num_node >= avg_num_node:
accumulate_num_node = 0
pp_size -= 1
if node.next.op == 'output':
with mod_graph.inserting_before(node):
split_node = mod_graph.create_node('call_function', pipe_split)
else:
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
return gm
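
A minimal usage sketch for the pass above; the toy model and stage count are illustrative, not part of the diff:

import torch
from torch.fx import symbolic_trace

# Hypothetical model; any traceable nn.Module with enough nodes works.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
)

gm = symbolic_trace(model)               # torch.fx.GraphModule
gm = avgnode_split_pass(gm, pp_size=2)   # inserts a pipe_split() node at each stage boundary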
def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In balanced_split_pass, we split the module by the size of its parameters (weights + bias).

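The body of balanced_split_pass is truncated in this view. As a hedged sketch of the strategy the docstring names (cut wherever the accumulated parameter size crosses an even share of the total), assuming the module-level torch import and the pipe_split symbol above; this is an illustration, not the committed implementation:

def balanced_split_sketch(gm: torch.fx.GraphModule, pp_size: int):
    # Even share of total parameter elements (weights + bias) per stage.
    total_numel = sum(p.numel() for p in gm.parameters())
    per_stage = total_numel // pp_size
    accumulated = 0
    for node in gm.graph.nodes:
        if pp_size <= 1:
            break
        # Only call_module nodes own parameters in this simplified view.
        if node.op == 'call_module':
            submod = gm.get_submodule(node.target)
            accumulated += sum(p.numel() for p in submod.parameters())
        if accumulated >= per_stage:
            accumulated = 0
            pp_size -= 1
            with gm.graph.inserting_after(node):
                gm.graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm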

@@ -16,6 +16,7 @@ from colossalai.pipeline.middleware import Partition, PartitionInputVal, Partiti
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (
get_batch_lengths,
pyobj_map,
pytree_filter,
pytree_map,
split_batch,
@@ -199,36 +200,28 @@ class WorkerBase(ABC):
with self.output_list_condition_lock:
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
output_work_item = self.output_list[key]
output = output_work_item.output
if not ref_use and output_work_item.phase != Phase.INPUT:
self.output_list.pop(key)
if not ref_use:
if not ref_use and output_work_item.phase != Phase.INPUT:
output_work_item.refcount += 1
refcount = output_work_item.refcount
output = output_work_item.output
if output_work_item.phase == Phase.FORWARD:
# lifecycle management for DAG scheduler
if output_work_item.phase == Phase.FORWARD:
lifecycle = len(self.get_consumer_stage_ids())
if self.is_model_output(): # an extra reference for scheduler collecting results
lifecycle += 1
with self.output_list_condition_lock:
# all consumers have been satisfied, the work_item can be released
# or put it into work list again.
if refcount < lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
elif output_work_item.phase == Phase.BACKWARD:
lifecycle = len(self.get_producer_stage_ids())
if self._is_last_step(output_work_item):
lifecycle += 1 # an extra reference for scheduler collecting results
with self.output_list_condition_lock:
# all producers have been satisfied, the work_item can be released
# or put it into work list again.
if refcount < lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
if self._is_last_step(output_work_item): # an extra reference for ensure_backward
lifecycle += 1
else:
lifecycle = 0
refcount = 0
with self.output_list_condition_lock:
if refcount < lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
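
The refcount/lifecycle bookkeeping above decides when a cached output can be released: every consumer takes one reference (plus an extra one for the scheduler or for ensure_backward), and the entry stays in output_list only while refcount < lifecycle. A condensed, self-contained sketch of the same idea, with illustrative names:

import threading

class OutputCache:
    # Sketch of refcount-vs-lifecycle release, not the WorkerBase code itself.

    def __init__(self):
        self._items = {}    # key -> (value, refcount, lifecycle)
        self._cond = threading.Condition()

    def put(self, key, value, lifecycle):
        with self._cond:
            self._items[key] = (value, 0, lifecycle)
            self._cond.notify_all()

    def get(self, key):
        with self._cond:
            self._cond.wait_for(lambda: key in self._items)
            value, refcount, lifecycle = self._items.pop(key)
            refcount += 1
            # Keep the entry until every expected consumer holds a reference;
            # once refcount reaches lifecycle, the entry stays released.
            if refcount < lifecycle:
                self._items[key] = (value, refcount, lifecycle)
                self._cond.notify_all()
            return value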
@@ -689,10 +682,12 @@ class WorkerBase(ABC):
else:
args_kwargs = self._get_real_args_kwargs_fwd(args)
if not forward_only:
pytree_map(args_kwargs,
lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
process_types=torch.Tensor)
# if not forward_only:
# pytree_map(args_kwargs,
# lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
# process_types=torch.Tensor)
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch RPC doesn't support args or rets on GPU
args, kwargs = data_process_func(args_kwargs)
@@ -762,6 +757,9 @@ class WorkerBase(ABC):
if is_last_stage: # if it is the last stage, trigger backward automatically
self._begin_backward(microbatch_id)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch RPC doesn't support args or rets on GPU
elif phase == Phase.BACKWARD:
# remind its producer to get data before backward
if not is_first_stage:
@@ -807,6 +805,8 @@ class WorkerBase(ABC):
stage_outputs = filtered_outputs
grad_tensors = filtered_grads
grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch RPC doesn't support args or rets on GPU
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
@@ -818,6 +818,9 @@ class WorkerBase(ABC):
consume_result.append(arg.grad)
else:
consume_result.append(None)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch RPC doesn't support args or rets on GPU
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -882,9 +885,6 @@ class WorkerBase(ABC):
# if it is the last step in one batch, reset context and do step
if self._is_last_step(work_item):
self._hook_before_step()
if hasattr(self, 'optimizer') and not work_item.forward_only:
self.step()
self._wait_for_reset()
# reset context and resume loop
@@ -904,23 +904,12 @@ class WorkerBase(ABC):
self.reset_condition.notify_all()
def initialize_optimizer(self, optimizer_class: type, **kwargs):
# TODO(jiangziyue): temporary code to deal with an empty module partition.
# Remove this once the tracer is fixed.
if len(list(self.module_partition.parameters())) > 0:
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
self.step_lock = threading.Lock()
self.step_lock.acquire()
def wait_for_step(self):
self.step_lock.acquire()
def step(self):
# TODO(jiangziyue): temporary code to deal with an empty module partition.
# Remove this once the tracer is fixed.
if len(list(self.module_partition.parameters())) > 0:
self._hook_before_step()
self.optimizer.step()
self.optimizer.zero_grad()
self.step_lock.release()
class PipelineEngineBase(ABC, nn.Module):
@@ -1176,10 +1165,7 @@ class PipelineEngineBase(ABC, nn.Module):
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
if not forward_only and hasattr(self, 'optimizer_class'):
# wait for all steps
for pp_rank in self.pp_rank_to_worker_rref:
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
worker_rref.rpc_sync().wait_for_step()
self.step()
self._reset_worker() # reset worker attributes for next batch
return forward_result
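
For context on the deleted wait_for_step/step_lock handshake: the engine used to block on a lock each worker released only after its own step(), so a worker stuck waiting for a microbatch that never arrived (the num_microbatch == num_stage case) left the engine blocked forever. A stripped-down sketch of that removed pattern, illustrative rather than verbatim:

import threading

step_lock = threading.Lock()
step_lock.acquire()            # held by the worker from initialization onward

def worker_step():
    # optimizer.step() / optimizer.zero_grad() would run here
    step_lock.release()        # signal the engine that stepping finished

def engine_wait_for_step():
    step_lock.acquire()        # blocks until worker_step() runs; hangs forever if it never does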


@@ -3,11 +3,12 @@ from typing import Callable, Dict, List
import torch
import torch.distributed as dist
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
# Implementations of different pipeline schedules
# <strategy>Worker defines the worker for each stage
# <strategy>PipelineEngine is the engine class exposed for use
@@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase):
outstanding_min = actual_stage_num - pp_rank - 1
outstanding_max = actual_stage_num - pp_rank
self.outstanding_range = (outstanding_min, outstanding_max)
elif target_key.microbatch_id == num_microbatches - 1:
if target_key.microbatch_id == num_microbatches - 1:
self.outstanding_range = (0, 0)
return target_key
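
The elif-to-if change above is the core of the deadlock fix: when the preceding branch fires on the same microbatch as num_microbatches - 1 (which happens when the number of microbatches equals the number of stages), the old elif skipped the reset of outstanding_range to (0, 0), and the worker kept waiting for forward work that never arrived. A toy illustration with made-up values:

num_microbatches = 4
microbatch_id = 3        # the last microbatch
warmup_trigger_id = 3    # hypothetical id matched by the earlier branch

outstanding_range = None
if microbatch_id == warmup_trigger_id:
    outstanding_range = (1, 2)                # placeholder non-empty range
if microbatch_id == num_microbatches - 1:     # was `elif`: unreachable here before the fix
    outstanding_range = (0, 0)                # with `if`, the reset always happens

assert outstanding_range == (0, 0)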


@@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.futures import Future
from colossalai.initialize import launch
from colossalai.pipeline.pipeline_process_group import ppg
def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
if isinstance(obj, process_types):
return fn(obj)
elif type(obj) is dict:
return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
elif type(obj) is tuple:
return tuple(pyobj_map(o, fn, process_types) for o in obj)
elif type(obj) is list:
return list(pyobj_map(o, fn, process_types) for o in obj)
else:
return obj
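
pyobj_map applies fn to every value of the given types inside plain dicts, tuples, and lists (recursively), passing all other leaves through unchanged. For instance, the diff uses it to stage tensors on CPU before they cross a torch RPC boundary; the nested structure below is illustrative:

import torch

args_kwargs = {
    'hidden': (torch.randn(2, 4), 3.14),    # a tensor and a non-tensor leaf
    'mask': [torch.ones(2, 4)],
}

cpu_args = pyobj_map(args_kwargs,
                     fn=lambda x: x.to('cpu').detach(),
                     process_types=torch.Tensor)
# Tensors are detached and moved to CPU; the float 3.14 passes through untouched.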
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
"""process object recursively, like pytree
@@ -57,6 +71,7 @@ def split_batch(batch: Any, start, stop, device: str):
def type_detail(obj):
return pytree_map(obj, lambda x: type(x), map_all=True)
def pytree_filter(fn, obj, process_types):
if obj is None:
return None