mirror of https://github.com/hpcaitech/ColossalAI

[Pipeline Middleware] Fix deadlock when num_microbatch=num_stage (#2156)

* add splitter
* polish code
* remove comment
* fix async nan by moving to cpu first

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>

branch: pull/2184/head
parent 937f404253
commit 59e343328d
@@ -9,6 +9,30 @@ def pipe_split():
     pass
 
 
+def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+    """
+    In avgnode_split_pass, simply split the graph by node number.
+    """
+    mod_graph = gm.graph
+    avg_num_node = len(mod_graph.nodes) // pp_size
+    accumulate_num_node = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        accumulate_num_node += 1
+        if accumulate_num_node >= avg_num_node:
+            accumulate_num_node = 0
+            pp_size -= 1
+            if node.next.op == 'output':
+                with mod_graph.inserting_before(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+            else:
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
 def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
     """
     In balanced_split_pass, we split module by the size of parameters(weights+bias).
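For orientation, the new pass is applied to a traced torch.fx.GraphModule and simply inserts a pipe_split() marker roughly every len(nodes) // pp_size nodes. A minimal usage sketch, assuming avgnode_split_pass and pipe_split are imported from the split-pass module this hunk belongs to (the file path is not shown in this view) and using a toy model:

    import torch.nn as nn
    from torch.fx import symbolic_trace

    # Toy model for illustration only; any symbolically traceable module works.
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())

    gm = symbolic_trace(model)               # torch.fx.GraphModule
    gm = avgnode_split_pass(gm, pp_size=2)   # inserts pipe_split() call_function nodes

    # The recompiled graph now carries the stage boundaries used by the
    # downstream pipeline partitioner.
    print(gm.graph)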
@@ -16,6 +16,7 @@ from colossalai.pipeline.middleware import Partition, PartitionInputVal, Partiti
 from colossalai.pipeline.pipeline_process_group import ppg
 from colossalai.pipeline.rpc.utils import (
     get_batch_lengths,
+    pyobj_map,
     pytree_filter,
     pytree_map,
     split_batch,
@@ -199,36 +200,28 @@ class WorkerBase(ABC):
         with self.output_list_condition_lock:
             self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
             output_work_item = self.output_list[key]
+            output = output_work_item.output
+            if not ref_use and output_work_item.phase != Phase.INPUT:
                 self.output_list.pop(key)
 
-        if not ref_use:
+        if not ref_use and output_work_item.phase != Phase.INPUT:
             output_work_item.refcount += 1
             refcount = output_work_item.refcount
-            output = output_work_item.output
-            if output_work_item.phase == Phase.FORWARD:
             # lifecycle management for DAG scheduler
+            if output_work_item.phase == Phase.FORWARD:
                 lifecycle = len(self.get_consumer_stage_ids())
                 if self.is_model_output():    # an extra reference for scheduler collecting results
                     lifecycle += 1
-                with self.output_list_condition_lock:
-                    # all consumers have been satisfied, the work_item can be released
-                    # or put it into work list again.
-                    if refcount < lifecycle:
-                        self.output_list[key] = output_work_item
-                        self.output_list_condition_lock.notify_all()
             elif output_work_item.phase == Phase.BACKWARD:
                 lifecycle = len(self.get_producer_stage_ids())
-                if self._is_last_step(output_work_item):
-                    lifecycle += 1    # an extra reference for scheduler collecting results
-                with self.output_list_condition_lock:
-                    # all producers have been satisfied, the work_item can be released
-                    # or put it into work list again.
-                    if refcount < lifecycle:
-                        self.output_list[key] = output_work_item
-                        self.output_list_condition_lock.notify_all()
+                if self._is_last_step(output_work_item):    # an extra reference for ensure_backward
+                    lifecycle += 1
             else:
+                lifecycle = 0
+                refcount = 0
 
             with self.output_list_condition_lock:
+                if refcount < lifecycle:
                     self.output_list[key] = output_work_item
                     self.output_list_condition_lock.notify_all()
 
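The hunk above restructures the output-list bookkeeping: the lifecycle is computed once per phase (number of consumer stages for FORWARD, plus one when the item is a model output the scheduler still has to collect; number of producer stages for BACKWARD, plus one on the last step), the refcount only advances when the item is actually consumed (not ref_use and not an INPUT item), and a single lock-protected block re-inserts the work item only while refcount < lifecycle. A stripped-down model of that release rule, for illustration only (this is not the actual WorkerBase code):

    import threading

    # Simplified model of the rule above: an item stays in the table until
    # `lifecycle` readers have fetched it once.
    class OutputTable:
        def __init__(self):
            self._items = {}                    # key -> (value, refcount, lifecycle)
            self._cond = threading.Condition()

        def put(self, key, value, lifecycle):
            with self._cond:
                self._items[key] = (value, 0, lifecycle)
                self._cond.notify_all()

        def get(self, key):
            with self._cond:
                self._cond.wait_for(lambda: key in self._items)
                value, refcount, lifecycle = self._items.pop(key)
            refcount += 1
            with self._cond:
                if refcount < lifecycle:        # more readers expected: put it back
                    self._items[key] = (value, refcount, lifecycle)
                    self._cond.notify_all()
            return value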
@@ -689,10 +682,12 @@ class WorkerBase(ABC):
             else:
                 args_kwargs = self._get_real_args_kwargs_fwd(args)
 
-            if not forward_only:
-                pytree_map(args_kwargs,
-                           lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
-                           process_types=torch.Tensor)
+            # if not forward_only:
+            #     pytree_map(args_kwargs,
+            #                lambda x: x.requires_grad_(True) if torch.is_floating_point(x) else x.requires_grad_(False),
+            #                process_types=torch.Tensor)
+            args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
+                                    process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
 
             args, kwargs = data_process_func(args_kwargs)
 
@@ -762,6 +757,9 @@ class WorkerBase(ABC):
                 if is_last_stage:    # if it is the last stage, trigger backward automatic
                     self._begin_backward(microbatch_id)
 
+            consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
+                                       process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
+
         elif phase == Phase.BACKWARD:
             # remind its producer to get data before backward
             if not is_first_stage:
@@ -807,6 +805,8 @@ class WorkerBase(ABC):
             stage_outputs = filtered_outputs
             grad_tensors = filtered_grads
 
+            grad_tensors = pyobj_map(grad_tensors, fn=lambda x: x.to(self.device),
+                                     process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
             autograd.backward(stage_outputs, grad_tensors=grad_tensors)
 
             # collect grad of input tensor
@@ -818,6 +818,9 @@ class WorkerBase(ABC):
                     consume_result.append(arg.grad)
                 else:
                     consume_result.append(None)
+            consume_result = pyobj_map(
+                consume_result, fn=lambda x: x.to('cpu'),
+                process_types=torch.Tensor)    # torch rpc doesn't support args or rets in GPU
 
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
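The pyobj_map calls added in the hunks above all serve the same purpose: every tensor that crosses a torch RPC boundary is moved to CPU on the way out and back onto self.device on the way in, because (as the in-line comments note) torch RPC cannot carry GPU tensors as arguments or return values here, and the commit message adds that moving results to CPU first also fixed the async NaN issue. A minimal sketch of the pattern (the helper names below are illustrative; pyobj_map is the utility added to colossalai.pipeline.rpc.utils later in this diff):

    import torch

    from colossalai.pipeline.rpc.utils import pyobj_map   # added by this commit

    def to_cpu_for_rpc(result):
        # Recursively move tensors in dicts/tuples/lists to CPU before they are
        # handed to torch RPC as a return value.
        return pyobj_map(result, fn=lambda x: x.to('cpu'), process_types=torch.Tensor)

    def to_device_after_rpc(payload, device):
        # Symmetric step on the receiving worker: put tensors back on its device.
        return pyobj_map(payload, fn=lambda x: x.to(device), process_types=torch.Tensor)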
@@ -882,9 +885,6 @@ class WorkerBase(ABC):
 
             # if is last step in one batch reset context and do step
             if self._is_last_step(work_item):
-                self._hook_before_step()
-                if hasattr(self, 'optimizer') and not work_item.forward_only:
-                    self.step()
                 self._wait_for_reset()
 
             # reset context and resume loop
@@ -904,23 +904,12 @@ class WorkerBase(ABC):
             self.reset_condition.notify_all()
 
     def initialize_optimizer(self, optimizer_class: type, **kwargs):
-        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
-        # After tracer fixed, remove this part.
-        if len(list(self.module_partition.parameters())) > 0:
         self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
-        self.step_lock = threading.Lock()
-        self.step_lock.acquire()
-
-    def wait_for_step(self):
-        self.step_lock.acquire()
 
     def step(self):
-        # TODO(jiangziyue) it's temporary code to deal with empty module partition.
-        # After tracer fixed, remove this part.
-        if len(list(self.module_partition.parameters())) > 0:
+        self._hook_before_step()
         self.optimizer.step()
         self.optimizer.zero_grad()
-        self.step_lock.release()
 
 
 class PipelineEngineBase(ABC, nn.Module):
@@ -1176,10 +1165,7 @@ class PipelineEngineBase(ABC, nn.Module):
         forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
 
         if not forward_only and hasattr(self, 'optimizer_class'):
-            # wait for all step
-            for pp_rank in self.pp_rank_to_worker_rref:
-                worker_rref = self.pp_rank_to_worker_rref[pp_rank]
-                worker_rref.rpc_sync().wait_for_step()
+            self.step()
 
         self._reset_worker()    # reset worker attributes for next batch
         return forward_result
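Together with the worker-side changes above (step_lock and wait_for_step removed, step() now calling _hook_before_step()), the engine no longer blocks on a per-worker rpc_sync().wait_for_step() handshake; it triggers the optimizer step explicitly through self.step(). The body of that engine-side step() is not part of this diff; the sketch below only illustrates the kind of fan-out it implies, assuming each worker RRef exposes the step() method shown above:

    # Sketch only: the engine-side step() implementation is not shown in this commit.
    # `pp_rank_to_worker_rref` maps pipeline rank -> RRef of a worker exposing step().
    def engine_step(pp_rank_to_worker_rref):
        step_futs = []
        for pp_rank in pp_rank_to_worker_rref:
            worker_rref = pp_rank_to_worker_rref[pp_rank]
            # asynchronous fan-out instead of the removed rpc_sync().wait_for_step() loop
            step_futs.append(worker_rref.rpc_async().step())
        for fut in step_futs:
            fut.wait()      # every stage has now stepped and zeroed its optimizer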
@@ -3,11 +3,12 @@ from typing import Callable, Dict, List
 
 import torch
 import torch.distributed as dist
-from colossalai.pipeline.pipeline_process_group import ppg
-from colossalai.pipeline.rpc._pipeline_base import (Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem)
 from torch._C._distributed_rpc import PyRRef
 from torch.futures import Future
 
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc._pipeline_base import Phase, PipelineEngineBase, UniqueKey, WorkerBase, WorkItem
+
 # Implementation of different Pipeline schedule
 # <strategy>Worker defines the worker for each stage
 # <strategy>PipelineEngine is the class for use
@@ -86,7 +87,7 @@ class OneFOneBWorker(WorkerBase):
                 outstanding_min = actual_stage_num - pp_rank - 1
                 outstanding_max = actual_stage_num - pp_rank
                 self.outstanding_range = (outstanding_min, outstanding_max)
-            elif target_key.microbatch_id == num_microbatches - 1:
+            if target_key.microbatch_id == num_microbatches - 1:
                 self.outstanding_range = (0, 0)
 
         return target_key
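Changing elif to if matters precisely for the case in the commit title. Assuming the preceding branch (its condition is outside this hunk) also fires on the boundary microbatch when num_microbatches equals actual_stage_num, an elif would skip the reset to (0, 0) for the last microbatch and the 1F1B worker would keep expecting outstanding work; an independent if applies both updates. A toy check of that boundary case, with the first branch's trigger treated as an assumption:

    num_microbatches = 4
    actual_stage_num = 4        # num_microbatch == num_stage, the case in the title
    microbatch_id = num_microbatches - 1

    # Assumption (not shown in the hunk): the first branch also triggers on the
    # boundary microbatch when the two counts are equal.
    first_branch = (microbatch_id == actual_stage_num - 1)
    reset_branch = (microbatch_id == num_microbatches - 1)

    print(first_branch, reset_branch)   # True True: `if` runs both, `elif` would skip the reset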
@@ -6,11 +6,25 @@ from typing import Any, Callable, Dict, List, Tuple, Type, Union
 import torch
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
-from colossalai.initialize import launch
-from colossalai.pipeline.pipeline_process_group import ppg
 from torch._C._distributed_rpc import _is_current_rpc_agent_set
 from torch.futures import Future
 
+from colossalai.initialize import launch
+from colossalai.pipeline.pipeline_process_group import ppg
+
+
+def pyobj_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = ()) -> Any:
+    if isinstance(obj, process_types):
+        return fn(obj)
+    elif type(obj) is dict:
+        return {k: pyobj_map(obj[k], fn, process_types) for k in obj}
+    elif type(obj) is tuple:
+        return tuple(pyobj_map(o, fn, process_types) for o in obj)
+    elif type(obj) is list:
+        return list(pyobj_map(o, fn, process_types) for o in obj)
+    else:
+        return obj
+
 
 def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
     """process object recursively, like pytree
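The new pyobj_map helper rebuilds dicts, tuples and lists while applying fn only to leaves whose type matches process_types, returning everything else untouched. A quick usage example (the import path matches the one added to the worker module's import block earlier in this diff; the values are illustrative):

    import torch

    from colossalai.pipeline.rpc.utils import pyobj_map

    batch = {
        'hidden': torch.randn(2, 4),    # tensor leaf: fn is applied
        'lengths': [3, 1],              # non-tensor leaves pass through unchanged
        'meta': {'tag': 'demo'},
    }

    cpu_batch = pyobj_map(batch, fn=lambda x: x.to('cpu'), process_types=torch.Tensor)

    assert cpu_batch['hidden'].device.type == 'cpu'
    assert cpu_batch['lengths'] == [3, 1] and cpu_batch['meta'] == {'tag': 'demo'}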
@@ -57,6 +71,7 @@ def split_batch(batch: Any, start, stop, device: str):
 def type_detail(obj):
     return pytree_map(obj, lambda x: type(x), map_all=True)
 
+
 def pytree_filter(fn, obj, process_types):
     if obj is None:
         return None