import copy
import heapq

from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger
from colossalai.utils import set_to_cuda


def _binary_partition(weights, st, ed):
    """Returns the binary partition position of `weights`, given the start
    position `st` and the end position `ed`.

    :param weights: a python list of prefix sums to be binary partitioned
    :type weights: list
    :param st: the start position of the binary partition
    :type st: int
    :param ed: the end position of the binary partition
    :type ed: int
    :return: the start position, the binary partition position and the end position
    :rtype: tuple
    """
    w_sum = weights[ed - 1]
    prefix = 0
    if st > 0:
        w_sum -= weights[st - 1]
        prefix = weights[st - 1]
    minimum = float("inf")
    # The caller guarantees ed - st >= 2, so the loop runs at least once
    # and `pos` is always assigned.
    for idx in range(st + 1, ed):
        front = weights[idx - 1] - prefix
        diff = abs(w_sum - 2 * front)
        if diff < minimum:
            pos = idx
            minimum = diff

    return st, pos, ed


def _heap_addition(weights, intervals, add_cnt):
    """Repeatedly splits the heaviest interval in `intervals` at its binary
    partition position until `add_cnt` extra intervals have been created.
    """
    def _heap_push(heap, st, ed):
        value = weights[ed - 1]
        if st > 0:
            value -= weights[st - 1]
        heapq.heappush(heap, (-value, st, ed))

    ret_intervals = []
    heap = []

    for st, ed in intervals:
        _heap_push(heap, st, ed)

    while add_cnt > 0:
        _, st, ed = heapq.heappop(heap)
        if ed - st == 1:
            # A single-layer interval cannot be split any further.
            ret_intervals.append((st, ed))
        else:
            l, m, r = _binary_partition(weights, st, ed)
            _heap_push(heap, l, m)
            _heap_push(heap, m, r)
            add_cnt -= 1

    while heap:
        _, st, ed = heapq.heappop(heap)
        ret_intervals.append((st, ed))

    ret_intervals.sort()
    return ret_intervals


def _calc_partitions(weights, value):
    """Greedily cuts the prefix-sum list `weights` into blocks whose weight
    does not exceed `value`; returns the block count and the (st, ed) intervals.
    """
    prev = 0
    prefix = 0
    num_block = 0
    intervals = []

    for idx, w in enumerate(weights):
        if w - prefix > value:
            intervals.append((prev, idx))
            prev = idx
            prefix = weights[idx - 1]
            num_block += 1

    intervals.append((prev, len(weights)))
    return num_block + 1, intervals


def _binary_search(weights, num):
    """Binary-searches the smallest per-block capacity that yields at most
    `num` blocks, then splits the heaviest blocks until there are exactly
    `num` intervals.
    """
    length = len(weights)
    prefix = [1 if w == 0 else w for w in weights]
    for i in range(1, length):
        prefix[i] += prefix[i - 1]

    lower_bound = max(weights)
    upper_bound = prefix[length - 1]

    while upper_bound > lower_bound:
        mid = (upper_bound + lower_bound) // 2
        number, _ = _calc_partitions(prefix, mid)
        if number <= num:
            upper_bound = mid
        else:
            lower_bound = mid + 1

    num_block, intervals = _calc_partitions(prefix, upper_bound)
    if num_block < num:
        intervals = _heap_addition(prefix, intervals, num - num_block)

    return intervals


def _partition_uniform(num_items, num_parts, num_chunks):
    """Evenly assigns `num_items` layers to `num_parts` pipeline stages,
    chunk by chunk.
    """
    assert num_items % num_chunks == 0, \
        "The number of layers should be divisible by the number of chunks, otherwise the 'parameter' method is recommended"

    logger = get_global_dist_logger()
    parts = [[] for _ in range(num_parts)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // num_parts
        left = num_parts - partition_items % num_parts
        if chunk_size == 0:
            logger.warning("Some stages in the pipeline are assigned no layers")

        for p in range(num_parts):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))

    return parts


def _partition_balanced(weights, num_parts, num_chunks):
    """Partitions layers into `num_parts * num_chunks` intervals with balanced
    weight sums and assigns them to the stages in round-robin order.
    """
    num_total = num_parts * num_chunks
    num_items = len(weights)
    if num_items <= num_total:
        return _partition_uniform(num_items, num_parts, num_chunks)

    intervals = _binary_search(weights, num_total)

    current = 0
    parts = [[] for _ in range(num_parts)]
    for inter in intervals:
        parts[current].append(inter)
        current = (current + 1) % num_parts

    return parts
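
# A minimal illustration of the balanced partitioner above (the weight list
# is hypothetical, not part of the module). With six layer weights, two
# pipeline stages, and one chunk, `_partition_balanced` returns one (st, ed)
# interval per stage whose weight sums are as even as possible:
#
#     >>> _partition_balanced([2, 1, 1, 1, 1, 2], num_parts=2, num_chunks=1)
#     [[(0, 3)], [(3, 6)]]
#
# Layers 0-2 and layers 3-5 each carry a total weight of 4.
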
class ModelInitializer:
    def __init__(self, config, num_chunks, verbose=False):
        self.num_chunks = num_chunks
        self.ori_model = build_model(config)
        self.layers = self.ori_model.layers_cfg
        layer_length = len(self.layers)
        self.verbose = verbose
        self._logger = get_global_dist_logger()
        self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])

    def model_initialize(self, partition_method='parameter'):
        # Some space for initializing communication groups
        self._interval = None
        self._partition_layers(method=partition_method)
        models = self._build()
        model = set_to_cuda(models)
        return model

    def _partition_layers(self, method):
        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        method = method.lower()
        # Make a partition
        if method == 'layer':
            num_layers = len(self.layers)
            self.parts = _partition_uniform(num_layers, pipeline_parallel_size, self.num_chunks)
        elif method == 'parameter':
            param_counts = self._count_layer_params()
            self.parts = _partition_balanced(param_counts, pipeline_parallel_size, self.num_chunks)
        else:
            raise ValueError(f"Partition method '{method}' is not supported; use 'layer' or 'parameter'")

        # Display the partition
        if gpc.get_global_rank() == 0 and self.verbose:
            log_str = 'Layer allocation after partitioning: \n'
            for stage in range(pipeline_parallel_size):
                num_layers = 0
                for st, ed in self.parts[stage]:
                    num_layers += ed - st

                log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
                for st, ed in self.parts[stage]:
                    for idx, layer in enumerate(self.layers[st:ed]):
                        log_str += f'\t{idx + st:2d}: {layer}\n'
            self._logger.info(log_str)

        # Save the partition
        self._interval = self.parts[pipeline_rank]

    def _build(self):
        """Build the pipeline sub-models from the layer cfg according to the partition
        """
        models = []
        for st, ed in self._interval:
            model = copy.copy(self.ori_model)
            model.build_from_cfg(st, ed)
            models.append(model)

        return models

    def _count_layer_params(self):
        """Count the number of trainable parameters in each layer
        """
        param_counts = [0] * len(self.layers)
        for idx, cfg in enumerate(self.layers):
            layer = build_layer(cfg)
            params = filter(lambda p: p.requires_grad, layer.parameters())
            param_counts[idx] = sum(p.numel() for p in params)

        return param_counts
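
# A hedged usage sketch, assuming colossalai has already been launched so
# that the PIPELINE parallel context in `gpc` is initialized; `model_config`
# is a hypothetical config dict understood by `build_model`:
#
#     initializer = ModelInitializer(model_config, num_chunks=1, verbose=True)
#     model = initializer.model_initialize(partition_method='parameter')
#
# Each rank builds only the layers in its own partition interval(s), and
# `set_to_cuda` then places the resulting sub-model on the local device.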