
add doc for search chunk

pull/2364/head
oahzxl 2 years ago
commit a68d240ed5
1 changed file with 76 changed lines:
colossalai/autochunk/search_chunk.py

@@ -1,4 +1,7 @@
import copy
from typing import Any, Dict, Iterable, List, Tuple
from torch.fx.node import Node
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
@@ -13,6 +16,34 @@ from .utils import (
class SearchChunk(object):
    """
    This is the core class for AutoChunk.

    It defines the framework of the strategy of AutoChunk.
    Chunks will be selected one by one until the search stops.

    The chunk search is as follows:
    1. find the peak memory node
    2. find the max chunk region according to the peak memory node
    3. find all possible chunk regions in the max chunk region
    4. find the best chunk region for the current status
    5. goto 1

    Attributes:
        gm: graph model
        print_mem (bool): print estimated memory
        trace_index: trace the flow of every dim of every node to find all free dims
        trace_flow: determine the region chunk strategy
        reorder_graph: reorder nodes to improve chunk efficiency
        estimate_memory: estimate memory with chunk
        select_chunk: select the best chunk region

    Args:
        gm: graph model
        max_memory (int): max memory in MB
        print_mem (bool): print estimated memory
    """

    def __init__(self, gm, max_memory=None, print_mem=False) -> None:
        self.gm = gm
        self.print_mem = print_mem
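
A minimal, runnable sketch of the search loop that the class docstring describes, using a toy list of per-node memory estimates instead of a real torch.fx graph; every name and number below is hypothetical and not part of AutoChunk:

# Toy sketch, not AutoChunk code: repeatedly chunk the current peak node's
# neighbourhood until the estimated peak memory fits the budget.
from typing import List, Tuple


def find_peak_node(mem_peak: List[float]) -> int:
    # step 1: the node with the highest estimated memory
    return mem_peak.index(max(mem_peak))


def max_chunk_region(peak_idx: int, num_nodes: int, radius: int = 2) -> Tuple[int, int]:
    # step 2 (simplified): grow a fixed-radius region around the peak node;
    # the real search instead stops where the number of active nodes is minimal
    return max(0, peak_idx - radius), min(num_nodes - 1, peak_idx + radius)


def search_chunks(mem_peak: List[float], max_memory: float) -> List[Tuple[int, int]]:
    mem = list(mem_peak)
    regions = []
    while max(mem) > max_memory:                        # stop once the peak fits
        peak = find_peak_node(mem)                      # step 1
        start, end = max_chunk_region(peak, len(mem))   # steps 2-4, collapsed
        regions.append((start, end))
        for i in range(start, end + 1):                 # pretend chunking halves memory here
            mem[i] /= 2                                 # step 5: goto 1
    return regions


print(search_chunks([10.0, 40.0, 80.0, 30.0, 12.0], max_memory=32.0))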
@@ -33,24 +64,37 @@ class SearchChunk(object):
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var(self):
    def _get_free_var_idx(self) -> List:
        """
        Get free var index

        Returns:
            free_var_idx (List): all indices of free vars
        """
        free_var_idx = []
        for idx, n in enumerate(self.trace_index.node_list):
            if n.op == "placeholder":
                free_var_idx.append(idx)
        return free_var_idx
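
For reference, "free vars" here are the placeholder nodes of the traced graph, i.e. its inputs. A small, self-contained illustration with a toy torch.fx module (the module and names are made up, not AutoChunk code):

import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return x + y


traced = symbolic_trace(Toy())
# placeholder nodes are the graph inputs, so they are the "free vars"
free_var_idx = [i for i, n in enumerate(traced.graph.nodes) if n.op == "placeholder"]
print(free_var_idx)  # x and y are the first two nodes -> [0, 1]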

    def _get_min_free_var(self, active_node_list, free_vars):
        # smallest number of active nodes among the non-free-var positions
        min_len = 999
        for idx, n in enumerate(active_node_list):
            if idx in free_vars:
                continue
            if len(n) < min_len:
                min_len = len(n)
        return min_len

    def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
        free_vars = self._get_free_var()
    def _search_max_chunk_region(
        self, active_node: List, peak_node: Node, chunk_regions: List
    ) -> Tuple:
        """
        Search max chunk region according to peak memory node

        Chunk region starts extending from the peak node and stops where
        the free var num is minimal.

        Args:
            active_node (List): active node status for every node
            peak_node (Node): peak memory node
            chunk_regions (List): chunk region info

        Returns:
            chunk_region_start (int)
            chunk_region_end (int)
        """
        free_vars = self._get_free_var_idx()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
@@ -92,16 +136,6 @@ class SearchChunk(object):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end
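
As a rough illustration of the region-growing idea in _search_max_chunk_region above: start at the peak node and extend in both directions while the count of active nodes stays above its minimum. The function and data below are invented for illustration only:

from typing import List, Tuple


def grow_region(active_node_num: List[int], peak_idx: int) -> Tuple[int, int]:
    # extend left/right from the peak until the active-node count drops
    # to its minimum (in AutoChunk that minimum relates to the free vars)
    min_active = min(active_node_num)
    start = end = peak_idx
    while start > 0 and active_node_num[start - 1] > min_active:
        start -= 1
    while end < len(active_node_num) - 1 and active_node_num[end + 1] > min_active:
        end += 1
    return start, end


# active-node counts per node; the peak-memory node sits at index 3
print(grow_region([1, 1, 3, 6, 4, 2, 1], peak_idx=3))  # -> (2, 5)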

    def _is_not_compute(self, trace, chunk_range, dim_idx):
        if trace["idx"][dim_idx] not in trace["compute"]:
            return True
        if trace["idx"][dim_idx] in trace["compute"] and all(
            i < chunk_range[0] or i > chunk_range[1]
            for i in trace["compute"][trace["idx"][dim_idx]]
        ):
            return True
        return False
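
A self-contained rendition of the check in _is_not_compute, with a toy trace whose layout is inferred from the code above rather than documented anywhere; a dim counts as "not compute" when it is never computed on inside the chunk range:

def is_not_compute(trace, chunk_range, dim_idx):
    # standalone copy of the method's logic, for illustration only
    dim_id = trace["idx"][dim_idx]
    if dim_id not in trace["compute"]:
        return True  # this dim never takes part in any computation
    # True only if every compute position lies outside [chunk_range[0], chunk_range[1]]
    return all(i < chunk_range[0] or i > chunk_range[1] for i in trace["compute"][dim_id])


# toy trace: dim ids per tensor dim, and compute positions per dim id
trace = {"idx": [0, 1], "compute": {1: [4, 9]}}
print(is_not_compute(trace, (5, 8), 0))   # True: dim 0 is never computed on
print(is_not_compute(trace, (5, 8), 1))   # True: nodes 4 and 9 are outside the range
print(is_not_compute(trace, (3, 10), 1))  # False: node 4 falls inside the range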

    def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
