# File: InternLM/internlm/solver/pipeline_utils.py (35 lines, 1.3 KiB, Python)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from internlm.utils.logger import get_logger
# Module-level logger for this file, created via the project's `get_logger` helper.
# NOTE(review): `partition_uniform` does not use it; it reports invalid input by raising.
logger = get_logger(__file__)
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    """Split ``num_items`` contiguous items (e.g. layers) evenly across pipeline stages.

    The items are first cut into ``num_chunks`` equally sized chunks. Each chunk is
    then divided among the ``pipeline_parallel_size`` stages as evenly as possible.
    When a chunk does not divide evenly, the last stages each take one extra item.

    Args:
        num_items: Total number of items to distribute.
        pipeline_parallel_size: Number of pipeline stages.
        num_chunks: Number of chunks the items are split into before distribution;
            every stage receives one contiguous range per chunk.

    Returns:
        A list with one entry per stage. Entry ``p`` is a list of ``num_chunks``
        half-open ``(start, end)`` index ranges assigned to stage ``p``.

    Raises:
        ValueError: if ``num_chunks`` or ``pipeline_parallel_size`` is not positive,
            if ``num_items`` is not divisible by ``num_chunks``, or if a chunk has
            fewer items than there are stages (some stage would receive nothing).
    """
    # Explicit checks rather than ``assert``: asserts are removed under ``python -O``,
    # which would let an invalid split through silently.
    if num_chunks <= 0 or pipeline_parallel_size <= 0:
        raise ValueError(
            f"num_chunks ({num_chunks}) and pipeline_parallel_size "
            f"({pipeline_parallel_size}) must be positive"
        )
    if num_items % num_chunks != 0:
        raise ValueError(
            "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
        )

    partition_items = num_items // num_chunks
    # All chunks have the same size, so the per-stage split is computed once.
    chunk_size = partition_items // pipeline_parallel_size
    if chunk_size <= 0:
        raise ValueError("Some nodes in Pipeline have no requests")
    # Stages with index >= ``left`` take one extra item each, absorbing the remainder.
    left = pipeline_parallel_size - partition_items % pipeline_parallel_size

    parts = [[] for _ in range(pipeline_parallel_size)]
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))

    # Internal sanity check: the ranges must cover 0..num_items-1 exactly once.
    indexes = []
    for stage_parts in parts:
        for s, e in stage_parts:
            indexes.extend(range(s, e))
    assert len(indexes) == len(set(indexes)), indexes  # should have no duplicates
    assert set(indexes) == set(range(num_items)), (indexes, num_items)  # should have the same indexes as expected
    return parts