mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline Middleware ] Fix deadlock when num_microbatch=num_stage (#2156)
* add splitter * polish code * remove comment * fix async nan by moving to cpu first Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>pull/2184/head
parent
937f404253
commit
59e343328d
|
@ -9,6 +9,30 @@ def pipe_split():
|
|||
pass
|
||||
|
||||
|
||||
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
"""
|
||||
In avgnode_split_pass, simpliy split graph by node number.
|
||||
"""
|
||||
mod_graph = gm.graph
|
||||
avg_num_node = len(mod_graph.nodes) // pp_size
|
||||
accumulate_num_node = 0
|
||||
for node in mod_graph.nodes:
|
||||
if pp_size <= 1:
|
||||
break
|
||||
accumulate_num_node += 1
|
||||
if accumulate_num_node >= avg_num_node:
|
||||
accumulate_num_node = 0
|
||||
pp_size -= 1
|
||||
if node.next.op == 'output':
|
||||
with mod_graph.inserting_before(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
else:
|
||||
with mod_graph.inserting_after(node):
|
||||
split_node = mod_graph.create_node('call_function', pipe_split)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
"""
|
||||
In balanced_split_pass, we split module by the size of parameters(weights+bias).
|
||||
|
|
|
@ -16,6 +16,7 @@ from colossalai.pipeline.middleware import Partition, PartitionInputVal, Partiti
|
|||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc.utils import (
|
||||
get_batch_lengths,
|
||||
pyobj_map,
|
||||
pytree_filter,
|
||||
pytree_map,
|
||||
split_batch,
|
||||
|
@ -199,36 +200,28 @@ class WorkerBase(ABC):
|
|||
with self.output_list_condition_lock:
|
||||
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||
output_work_item = self.output_list[key]
|
||||
output = output_work_item.output
|
||||
if not ref_use and output_work_item.phase != Phase.INPUT:
|
||||
self.output_list.pop(key)
|
||||
|
||||
if not ref_use:
|
||||
if not ref_use and output_work_item.phase != Phase.INPUT:
|
||||
output_work_item.refcount += 1
|
||||
refcount = output_work_item.refcount
|
||||
output = output_work_item.output
|
||||
|
||||
if output_work_item.phase == Phase.FORWARD:
|
||||
# lifecycle management for DAG scheduler
|
||||
if output_work_item.phase == Phase.FORWARD:
|
||||
lifecycle = len(self.get_consumer_stage_ids())
|
||||
if self.is_model_output(): # an extra reference for scheduler collecting results
|
||||
lifecycle += 1
|
||||
with self.output_list_condition_lock:
|
||||
# all consumers have been satisfied, the work_item can be released
|
||||
# or put it into work list again.
|
||||
if refcount < lifecycle:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
elif output_work_item.phase == Phase.BACKWARD:
|
||||
lifecycle = len(self.get_producer_stage_ids())
|
||||
if self._is_last_step(output_work_item):
|
||||
lifecycle += 1 # an extra reference for scheduler collecting results
|
||||
with self.output_list_condition_lock:
|
||||
# all producers have been satisfied, the work_item can be released
|
||||
# or put it into work list again.
|
||||
if refcount < lifecycle:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
if self._is_last_step(output_work_item): # an extra reference for ensure_backward
|
||||
lifecycle += 1
|
||||
else:
|
||||
lifecycle = 0
|
||||
refcount = 0
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
if refcount < lifecycle:
|
||||
self.output_list[key] = output_work_item
|
||||
self.output_list_condition_lock.notify_all()
|
||||
|
||||
|
@ -689,10 +682,12 @@ class WorkerBase(ABC):
|
|||
else:
|
||||
args_kwargs = self._get_real_args_kwargs_fwd(args)
|
||||
|
||||
if not forward_only:
|
||||
pytree_map(args_kwargs,
|
||||
lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
|
||||
process_types=torch.Tensor)
|
||||
# if not forward_only:
|
||||
# pytree_map(args_kwargs,
|
||||
# lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
|
||||
# process_types=torch.Tensor)
|
||||
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
|
@ -762,6 +757,9 @@ class WorkerBase(ABC):
|
|||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
self._begin_backward(microbatch_id)
|
||||
|
||||
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
elif phase == Phase.BACKWARD:
|
||||
# remind its producer to get data before backward
|
||||
if not is_first_stage:
|
||||
|
@ -807,6 +805,8 @@ class WorkerBase(ABC):
|
|||
stage_outputs = filtered_outputs
|
||||
grad_tensors = filtered_grads
|
||||
|
||||
grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
|
||||
# collect grad of input tensor
|
||||
|
@ -818,6 +818,9 @@ class WorkerBase(ABC):
|
|||
consume_result.append(arg.grad)
|
||||
else:
|
||||
consume_result.append(None)
|
||||
consume_result = pyobj_map(
|
||||
consume_result, fn=lambda x: x.to('cpu'),
|
||||
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||
|
@ -882,9 +885,6 @@ class WorkerBase(ABC):
|
|||
|
||||
# if is last step in one batch reset context and do step
|
||||
if self._is_last_step(work_item):
|
||||
self._hook_before_step()
|
||||
if hasattr(self, 'optimizer') and not work_item.forward_only:
|
||||
self.step()
|
||||
self._wait_for_reset()
|
||||
|
||||
# reset context and resume loop
|
||||
|
@ -904,23 +904,12 @@ class WorkerBase(ABC):
|
|||
self.reset_condition.notify_all()
|
||||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
||||
# After tracer fixed, remove this part.
|
||||
if len(list(self.module_partition.parameters())) > 0:
|
||||
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
||||
self.step_lock = threading.Lock()
|
||||
self.step_lock.acquire()
|
||||
|
||||
def wait_for_step(self):
|
||||
self.step_lock.acquire()
|
||||
|
||||
def step(self):
|
||||
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
||||
# After tracer fixed, remove this part.
|
||||
if len(list(self.module_partition.parameters())) > 0:
|
||||
self._hook_before_step()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.step_lock.release()
|
||||
|
||||
|
||||
class PipelineEngineBase(ABC, nn.Module):
|
||||
|
@ -1176,10 +1165,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
|
||||
|
||||
if not forward_only and hasattr(self, 'optimizer_class'):
|
||||
# wait for all step
|
||||
for pp_rank in self.pp_rank_to_worker_rref:
|
||||
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
|
||||
worker_rref.rpc_sync().wait_for_step()
|
||||
self.step()
|
||||
|
||||
self._reset_worker() # reset worker attributes for next batch
|
||||
return forward_result
|
||||
|
|
|
@ -3,11 +3,12 @@ from typing import Callable, Dict, List
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch.futures import Future
|
||||
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
|
||||
|
||||
# Implementation of different Pipeline schedule
|
||||
# <strategy>Worker defines the worker for each stage
|
||||
# <strategy>PipelineEngine is the class for use
|
||||
|
@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase):
|
|||
outstanding_min = actual_stage_num - pp_rank - 1
|
||||
outstanding_max = actual_stage_num - pp_rank
|
||||
self.outstanding_range = (outstanding_min, outstanding_max)
|
||||
elif target_key.microbatch_id == num_microbatches - 1:
|
||||
if target_key.microbatch_id == num_microbatches - 1:
|
||||
self.outstanding_range = (0, 0)
|
||||
|
||||
return target_key
|
||||
|
|
|
@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union
|
|||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from torch._C._distributed_rpc import _is_current_rpc_agent_set
|
||||
from torch.futures import Future
|
||||
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
|
||||
|
||||
def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
|
||||
if isinstance(obj, process_types):
|
||||
return fn(obj)
|
||||
elif type(obj) is dict:
|
||||
return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
|
||||
elif type(obj) is tuple:
|
||||
return tuple(pyobj_map(o, fn, process_types) for o in obj)
|
||||
elif type(obj) is list:
|
||||
return list(pyobj_map(o, fn, process_types) for o in obj)
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
|
||||
"""process object recursively, like pytree
|
||||
|
@ -57,6 +71,7 @@ def split_batch(batch: Any, start, stop, device: str):
|
|||
def type_detail(obj):
|
||||
return pytree_map(obj, lambda x: type(x), map_all=True)
|
||||
|
||||
|
||||
def pytree_filter(fn, obj, process_types):
|
||||
if obj is None:
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue