mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
6.3 KiB
207 lines
6.3 KiB
3 years ago
|
import heapq
|
||
|
import inspect
|
||
|
|
||
|
from colossalai.logging import get_dist_logger
|
||
|
from typing import List
|
||
|
|
||
|
def _binary_partition(weights: List, start: int, end: int):
|
||
|
"""Returns the binary partition position of `weights`, given the start
|
||
|
position `st` and the end position `ed`.
|
||
|
|
||
|
Args:
|
||
|
weights (list): A python list to be binary partitioned
|
||
|
start (int): the start position of the binary partition
|
||
|
end (int): the end position of the binary partition
|
||
|
|
||
|
Returns:
|
||
|
int: the binary partition position of `weights`
|
||
|
"""
|
||
|
w_sum = weights[end - 1]
|
||
|
prefix = 0
|
||
|
if start > 0:
|
||
|
w_sum -= weights[start - 1]
|
||
|
prefix = weights[start - 1]
|
||
|
minimum = float("inf")
|
||
|
for idx in range(start + 1, end):
|
||
|
front = weights[idx - 1] - prefix
|
||
|
diff = abs(w_sum - 2 * front)
|
||
|
if diff < minimum:
|
||
|
pos = idx
|
||
|
minimum = diff
|
||
|
|
||
|
return start, pos, end
|
||
|
|
||
|
|
||
|
def _heap_addition(weights: List, intervals: int, add_cnt: int):
|
||
|
"""
|
||
|
"""
|
||
|
|
||
|
def _heap_push(heap, st, ed):
|
||
|
value = weights[ed - 1]
|
||
|
if st > 0:
|
||
|
value -= weights[st - 1]
|
||
|
heapq.heappush(heap, (-value, st, ed))
|
||
|
|
||
|
ret_intervals = []
|
||
|
heap = []
|
||
|
|
||
|
for st, ed in intervals:
|
||
|
_heap_push(heap, st, ed)
|
||
|
|
||
|
while add_cnt > 0:
|
||
|
_, st, ed = heapq.heappop(heap)
|
||
|
if ed - st == 1:
|
||
|
ret_intervals.append((st, ed))
|
||
|
else:
|
||
|
l, m, r = _binary_partition(weights, st, ed)
|
||
|
_heap_push(heap, l, m)
|
||
|
_heap_push(heap, m, r)
|
||
|
add_cnt -= 1
|
||
|
|
||
|
while heap:
|
||
|
_, st, ed = heapq.heappop(heap)
|
||
|
ret_intervals.append((st, ed))
|
||
|
|
||
|
ret_intervals.sort()
|
||
|
return ret_intervals
|
||
|
|
||
|
|
||
|
def _calc_partitions(weights, value):
|
||
|
prev = 0
|
||
|
prefix = 0
|
||
|
num_block = 0
|
||
|
intervals = []
|
||
|
|
||
|
for idx, w in enumerate(weights):
|
||
|
if weights[idx] - prefix > value:
|
||
|
intervals.append((prev, idx))
|
||
|
prev = idx
|
||
|
prefix = weights[idx - 1]
|
||
|
num_block += 1
|
||
|
|
||
|
intervals.append((prev, len(weights)))
|
||
|
return num_block + 1, intervals
|
||
|
|
||
|
|
||
|
def _binary_search(weights, num):
    """Partition `weights` into `num` consecutive intervals of balanced weight.

    A binary search looks for the smallest per-block weight cap for which the
    greedy `_calc_partitions` needs at most `num` blocks. If the greedy result
    has fewer than `num` blocks, `_heap_addition` splits the heaviest ones to
    make up the difference.

    Args:
        weights (list): per-item weights (not prefix sums); the integer
            midpoint arithmetic below presumably expects integer weights
        num (int): number of intervals wanted

    Returns:
        list: half-open ``(start, end)`` index pairs in ascending order.
    """
    # Zero-weight items are counted as weight 1 (presumably so that every
    # item contributes to the balance -- confirm with the callers).
    adjusted = [1 if w == 0 else w for w in weights]

    # Running totals: prefix[i] == sum(adjusted[:i + 1]).
    prefix = []
    running = 0
    for w in adjusted:
        running += w
        prefix.append(running)

    # The cap can never be below the heaviest single item. It must be taken
    # from the adjusted weights: using the raw weights let the cap drop to 0
    # when all weights are zero, below every item's adjusted weight of 1.
    lower_bound = max(adjusted)
    upper_bound = prefix[-1]

    while upper_bound > lower_bound:
        mid = (upper_bound + lower_bound) // 2
        number, _ = _calc_partitions(prefix, mid)
        if number <= num:
            # Feasible cap: try a tighter one.
            upper_bound = mid
        else:
            lower_bound = mid + 1

    num_block, intervals = _calc_partitions(prefix, upper_bound)
    if num_block < num:
        intervals = _heap_addition(prefix, intervals, num - num_block)

    return intervals
|
||
|
|
||
|
|
||
|
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    """Split `num_items` items evenly over pipeline stages, chunk by chunk.

    The items are first divided into `num_chunks` equal, consecutive chunks.
    Each chunk is then divided into `pipeline_parallel_size` consecutive
    ranges; when the chunk does not divide evenly, the trailing stages each
    take one extra item.

    Args:
        num_items (int): total number of items; must be divisible by
            `num_chunks`
        pipeline_parallel_size (int): number of pipeline stages
        num_chunks (int): number of chunks handled by each stage

    Returns:
        list: ``parts[p]`` is the list of half-open ``(start, end)`` ranges
        assigned to stage ``p``, one range per chunk.

    Raises:
        AssertionError: if `num_items` is not divisible by `num_chunks`.
    """
    assert num_items % num_chunks == 0, \
        "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"

    partition_items = num_items // num_chunks

    # Every chunk has the same size, so the per-stage split is computed once
    # instead of once per chunk.
    chunk_size, remainder = divmod(partition_items, pipeline_parallel_size)
    # Stages with index >= `left` receive one extra item.
    left = pipeline_parallel_size - remainder

    if chunk_size == 0 and num_chunks > 0:
        # Emitted once rather than repeated identically for every chunk; the
        # logger is only fetched when there is something to report.
        get_dist_logger().warning("Some nodes in Pipeline have no requests")

    parts = [[] for _ in range(pipeline_parallel_size)]
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))

    return parts
|
||
|
|
||
|
|
||
|
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
    """Assign weight-balanced item ranges to pipeline stages.

    When there are no more items than ``pipeline_parallel_size * num_chunks``
    slots, the work is delegated to `partition_uniform`. Otherwise the items
    are cut into that many intervals by `_binary_search`, and the intervals
    are dealt to the stages in round-robin order.

    Args:
        weights (list): per-item weights
        pipeline_parallel_size (int): number of pipeline stages
        num_chunks (int): number of chunks handled by each stage

    Returns:
        list: ``parts[p]`` is the list of ``(start, end)`` ranges for stage ``p``.
    """
    slot_count = pipeline_parallel_size * num_chunks
    item_count = len(weights)
    if item_count <= slot_count:
        return partition_uniform(item_count, pipeline_parallel_size, num_chunks)

    balanced = _binary_search(weights, slot_count)

    stage_parts = [[] for _ in range(pipeline_parallel_size)]
    stage = 0
    for interval in balanced:
        stage_parts[stage].append(interval)
        # Wrap around to the first stage after the last one.
        stage = (stage + 1) % pipeline_parallel_size

    return stage_parts
|
||
|
|
||
|
|
||
|
def build_kwargs_for_module(function, kw_dict):
    """Select the entries of `kw_dict` that `function` accepts after its first parameter.

    The first parameter of a module's forward is generally the input tensor
    coming from the previous layer, so it is excluded from the filtering.

    Args:
        function: callable whose signature is inspected
        kw_dict (dict): candidate keyword arguments

    Returns:
        dict or None: the filtered keyword arguments (possibly empty), or
        ``None`` when `function` has at most one parameter.
    """
    param_names = list(inspect.signature(function).parameters)
    if len(param_names) <= 1:
        return None

    # Everything except the leading (input) parameter.
    accepted = param_names[1:]
    return {name: value for name, value in kw_dict.items() if name in accepted}
|
||
|
|
||
|
|
||
|
def build_kwargs_for_function(function, kw_dict):
    """Select the entries of `kw_dict` whose keys are parameters of `function`.

    Args:
        function: callable whose signature is inspected
        kw_dict (dict): candidate keyword arguments

    Returns:
        dict or None: the matching keyword arguments, or ``None`` when no key
        of `kw_dict` is a parameter name of `function`.
    """
    accepted = inspect.signature(function).parameters
    matching = {name: value for name, value in kw_dict.items() if name in accepted}
    # An empty selection is reported as None rather than as an empty dict.
    return matching if matching else None
|
||
|
|
||
|
|
||
|
def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
    """Run `func` either on the input tensor or on selected keyword arguments.

    A callable passed to the to_layer_list method is assumed to serve one of
    two purposes:

    a. transform the input tensor, e.g. ``lambda x: torch.flatten(x, 1)``;
    b. transform keyword-argument values, e.g. a function that takes
       ``attention_mask=None`` and returns a reshaped attention mask.

    Args:
        func: the callable to execute
        kw_dict (dict or None): keyword arguments to call `func` with;
            ``None`` selects purpose (a)
        input_tensor: the current input tensor
        kwargs (dict): keyword-argument store, updated in place for purpose (b)

    Returns:
        ``func(input_tensor)`` when `kw_dict` is ``None``. Otherwise
        `input_tensor` unchanged, after writing the result of
        ``func(**kw_dict)`` back into `kwargs`.
    """
    if kw_dict is None:
        # Purpose (a): the callable maps the tensor to a new tensor.
        return func(input_tensor)

    # Purpose (b): the callable rewrites keyword-argument values.
    outcome = func(**kw_dict)
    if isinstance(outcome, tuple):
        # A tuple result is matched positionally to the keys of `kw_dict`.
        for position, key in enumerate(kw_dict):
            kwargs[key] = outcome[position]
    else:
        # Any other result is stored under every key of `kw_dict`.
        for key in kw_dict:
            kwargs[key] = outcome
    return input_tensor
|
||
|
|
||
|
|
||
|
def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
    """Execute the callable(s) registered under `func_key`, in order.

    For each callable, the entries of `kwargs` matching its parameters are
    selected with `build_kwargs_for_function`, and the callable is run through
    `exec_func_with_kwargs`. The tensor returned by one call is the input to
    the next.

    Args:
        func_dict (dict): maps keys to a callable or a list of callables
        func_key: key to look up in `func_dict`
        input_tensor: the initial input tensor
        kwargs (dict): keyword-argument store shared by all callables

    Returns:
        The tensor produced after the last callable has run.

    Raises:
        AssertionError: if `func_key` is not in `func_dict`.
    """
    assert func_key in func_dict, f"{func_key} is not in the function_dict."

    registered = func_dict[func_key]
    # Treat a single callable as a one-element pipeline.
    pipeline = registered if isinstance(registered, list) else [registered]

    current = input_tensor
    for step in pipeline:
        step_kwargs = build_kwargs_for_function(step, kwargs)
        current = exec_func_with_kwargs(step, step_kwargs, current, kwargs)

    return current
|