mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
227 lines
6.9 KiB
227 lines
6.9 KiB
3 years ago
|
import copy
|
||
|
import heapq
|
||
|
|
||
|
from colossalai.builder import build_model, build_layer
|
||
|
from colossalai.context.parallel_mode import ParallelMode
|
||
|
from colossalai.core import global_context as gpc
|
||
|
from colossalai.logging import get_global_dist_logger
|
||
|
from colossalai.utils import set_to_cuda
|
||
|
|
||
|
|
||
|
def _binary_partition(weights, st, ed):
|
||
|
"""Returns the binary partition position of `weights`, given the start
|
||
|
position `st` and the end position `ed`.
|
||
|
|
||
|
:param weights: A python list to be binary partitioned
|
||
|
:type weights: list
|
||
|
:param st: the start position of the binary partition
|
||
|
:type st: int
|
||
|
:param ed: the end postition of the binary partition
|
||
|
:type ed: int
|
||
|
:return: the binary partition position of `weights`
|
||
|
:rtype: int
|
||
|
"""
|
||
|
w_sum = weights[ed - 1]
|
||
|
prefix = 0
|
||
|
if st > 0:
|
||
|
w_sum -= weights[st - 1]
|
||
|
prefix = weights[st - 1]
|
||
|
minimum = float("inf")
|
||
|
for idx in range(st + 1, ed):
|
||
|
front = weights[idx - 1] - prefix
|
||
|
diff = abs(w_sum - 2 * front)
|
||
|
if diff < minimum:
|
||
|
pos = idx
|
||
|
minimum = diff
|
||
|
|
||
|
return st, pos, ed
|
||
|
|
||
|
|
||
|
def _heap_addition(weights, intervals, add_cnt):
|
||
|
"""
|
||
|
"""
|
||
|
def _heap_push(heap, st, ed):
|
||
|
value = weights[ed - 1]
|
||
|
if st > 0:
|
||
|
value -= weights[st - 1]
|
||
|
heapq.heappush(heap, (-value, st, ed))
|
||
|
|
||
|
ret_intervals = []
|
||
|
heap = []
|
||
|
|
||
|
for st, ed in intervals:
|
||
|
_heap_push(heap, st, ed)
|
||
|
|
||
|
while add_cnt > 0:
|
||
|
_, st, ed = heapq.heappop(heap)
|
||
|
if ed - st == 1:
|
||
|
ret_intervals.append((st, ed))
|
||
|
else:
|
||
|
l, m, r = _binary_partition(weights, st, ed)
|
||
|
_heap_push(heap, l, m)
|
||
|
_heap_push(heap, m, r)
|
||
|
add_cnt -= 1
|
||
|
|
||
|
while heap:
|
||
|
_, st, ed = heapq.heappop(heap)
|
||
|
ret_intervals.append((st, ed))
|
||
|
|
||
|
ret_intervals.sort()
|
||
|
return ret_intervals
|
||
|
|
||
|
|
||
|
def _calc_partitions(weights, value):
|
||
|
prev = 0
|
||
|
prefix = 0
|
||
|
num_block = 0
|
||
|
intervals = []
|
||
|
|
||
|
for idx, w in enumerate(weights):
|
||
|
if weights[idx] - prefix > value:
|
||
|
intervals.append((prev, idx))
|
||
|
prev = idx
|
||
|
prefix = weights[idx - 1]
|
||
|
num_block += 1
|
||
|
|
||
|
intervals.append((prev, len(weights)))
|
||
|
return num_block + 1, intervals
|
||
|
|
||
|
|
||
|
def _binary_search(weights, num):
    """Cut ``weights`` into ``num`` contiguous intervals with balanced total weight.

    Zero weights are counted as 1. The steps are:

    1. Binary-search the smallest block capacity for which the greedy
       ``_calc_partitions`` needs at most ``num`` blocks.
    2. If it needs fewer than ``num``, split the heaviest blocks with
       ``_heap_addition`` until there are ``num``.

    Presumably requires ``num <= len(weights)``; ``_partition_balanced`` only
    calls this when that holds.

    :param weights: per-item weights (the caller passes per-layer parameter counts)
    :type weights: list
    :param num: the number of intervals to produce
    :type num: int
    :return: half-open ``(st, ed)`` intervals covering all items
    :rtype: list
    """
    adjusted = [1 if w == 0 else w for w in weights]
    prefix = list(adjusted)
    for i in range(1, len(prefix)):
        prefix[i] += prefix[i - 1]

    # The capacity can never be below the heaviest single item. Use the
    # *adjusted* weights: with the raw ones, an all-zero list gives a lower bound
    # of 0, and a capacity of 0 yields empty/overlapping intervals.
    lower_bound = max(adjusted)
    upper_bound = prefix[-1]

    while upper_bound > lower_bound:
        mid = (upper_bound + lower_bound) // 2
        number, _ = _calc_partitions(prefix, mid)
        if number <= num:
            upper_bound = mid
        else:
            lower_bound = mid + 1

    num_block, intervals = _calc_partitions(prefix, upper_bound)
    if num_block < num:
        intervals = _heap_addition(prefix, intervals, num - num_block)

    return intervals
|
||
|
|
||
|
|
||
|
def _partition_uniform(num_items, num_parts, num_chunks):
|
||
|
assert num_items % num_chunks == 0, \
|
||
|
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
|
||
|
|
||
|
logger = get_global_dist_logger()
|
||
|
parts = [[] for _ in range(num_parts)]
|
||
|
partition_items = num_items // num_chunks
|
||
|
for idx in range(num_chunks):
|
||
|
base_idx = idx * partition_items
|
||
|
chunk_size = partition_items // num_parts
|
||
|
left = num_parts - partition_items % num_parts
|
||
|
if chunk_size == 0:
|
||
|
logger.warning("Some nodes in Pipeline have no requests")
|
||
|
|
||
|
for p in range(num_parts):
|
||
|
st = base_idx
|
||
|
base_idx += chunk_size + (p >= left)
|
||
|
parts[p].append((st, base_idx))
|
||
|
|
||
|
return parts
|
||
|
|
||
|
|
||
|
def _partition_balanced(weights, num_parts, num_chunks):
    """Assign items to ``num_parts`` parts with ``num_chunks`` chunks each, balancing ``weights``.

    The behavior depends on how many items there are:

    - If there are no more items than ``num_parts * num_chunks`` slots, the
      uniform split from ``_partition_uniform`` is returned.
    - Otherwise the items are cut into ``num_parts * num_chunks`` weight-balanced
      intervals, which are dealt out to the parts round-robin.

    :param weights: per-item weights
    :type weights: list
    :param num_parts: number of parts (pipeline stages)
    :type num_parts: int
    :param num_chunks: number of chunks per part
    :type num_chunks: int
    :return: ``num_parts`` lists of half-open ``(st, ed)`` intervals
    :rtype: list
    """
    total_slots = num_parts * num_chunks
    item_count = len(weights)
    if item_count <= total_slots:
        return _partition_uniform(item_count, num_parts, num_chunks)

    balanced = _binary_search(weights, total_slots)

    parts = [[] for _ in range(num_parts)]
    for position, interval in enumerate(balanced):
        # Round-robin: the i-th interval goes to part i mod num_parts.
        parts[position % num_parts].append(interval)
    return parts
|
||
|
|
||
|
|
||
|
class ModelInitializer:
    """Builds this pipeline rank's share of a model described by a layer config.

    The full model is built once from ``config``. Its ``layers_cfg`` list is
    partitioned across pipeline stages, with ``num_chunks`` intervals per stage.
    Only the intervals assigned to this process's pipeline rank are then built.
    """

    def __init__(self, config, num_chunks, verbose=False):
        """
        :param config: model config passed to ``build_model``
        :param num_chunks: number of layer intervals (chunks) per pipeline stage
        :type num_chunks: int
        :param verbose: if True, global rank 0 logs the layer allocation after partitioning
        :type verbose: bool
        """
        self.num_chunks = num_chunks
        self.ori_model = build_model(config)
        self.layers = self.ori_model.layers_cfg
        layer_length = len(self.layers)
        self.verbose = verbose
        self._logger = get_global_dist_logger()
        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])

    def model_initialize(self, partition_method='parameter'):
        """Partition the layers and build this rank's model chunks.

        :param partition_method: ``'layer'`` to split by layer count, or
            ``'parameter'`` to balance by trainable parameter count (case-insensitive)
        :type partition_method: str
        :raises ValueError: if ``partition_method`` is not a supported method
        :return: the result of ``set_to_cuda`` applied to the list of built chunks
        """
        # Reserved for initializing communication groups.
        self._interval = None
        self._partition_layers(method=partition_method)
        models = self._build()
        model = set_to_cuda(models)

        return model

    def _partition_layers(self, method):
        """Compute ``self.parts`` for every stage and store this rank's share in ``self._interval``.

        :param method: ``'layer'`` or ``'parameter'`` (case-insensitive)
        :type method: str
        :raises ValueError: if ``method`` is not one of the supported strings
        """
        method = method.lower()
        # Validate explicitly, before querying the parallel context. An ``assert``
        # would vanish under ``python -O`` and the failure would resurface later
        # as an unrelated AttributeError on ``self.parts``.
        if method not in ('layer', 'parameter'):
            raise ValueError(
                f"Method should be a pre-set string ('layer' or 'parameter'), got {method!r}")

        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        # Make a partition
        if method == 'layer':
            num_layers = len(self.layers)
            self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
        else:
            param_counts = self._count_layer_params()
            self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)

        # Display the partition
        if gpc.get_global_rank() == 0 and self.verbose:
            log_str = 'Layer allocation after partitioning: \n'
            for stage in range(pipeline_parallel_size):
                num_layers = sum(ed - st for st, ed in self.parts[stage])
                log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
                for st, ed in self.parts[stage]:
                    for idx, layer in enumerate(self.layers[st: ed], start=st):
                        log_str += f'\t{idx:2d}: {layer}\n'
            self._logger.info(log_str)

        # Save the partition
        self._interval = self.parts[pipeline_rank]

    def _build(self):
        """Build one model per ``(st, ed)`` interval in ``self._interval``.

        :return: list of models, each a shallow copy of ``ori_model`` on which
            ``build_from_cfg(st, ed)`` has been called
        :rtype: list
        """
        models = []
        for st, ed in self._interval:
            # NOTE(review): ``copy.copy`` is shallow. If ``ori_model`` is a
            # ``torch.nn.Module``, the copies share its internal ``_modules`` /
            # ``_parameters`` dicts, so submodules set by ``build_from_cfg`` might
            # leak between chunks. Confirm this is safe when num_chunks > 1.
            model = copy.copy(self.ori_model)
            model.build_from_cfg(st, ed)
            models.append(model)

        return models

    def _count_layer_params(self):
        """Count the trainable (``requires_grad``) parameters of each layer config.

        Each layer is instantiated with ``build_layer`` only to inspect its parameters.

        :return: one parameter count per entry of ``self.layers``
        :rtype: list
        """
        param_counts = []
        for cfg in self.layers:
            layer = build_layer(cfg)
            param_counts.append(
                sum(p.numel() for p in layer.parameters() if p.requires_grad))

        return param_counts
|