mirror of https://github.com/InternLM/InternLM
35 lines
1.3 KiB
Python
35 lines
1.3 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from internlm.utils.logger import get_logger
|
|
|
|
# Module-level logger from the project's logging helper, created from this
# file's path (``__file__``).
logger = get_logger(__file__)
|
|
|
|
|
|
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    """Partition ``num_items`` consecutive items evenly across pipeline ranks.

    The items are first cut into ``num_chunks`` equal, contiguous chunks. Each
    chunk is then split into ``pipeline_parallel_size`` contiguous half-open
    ranges whose sizes differ by at most one. Rank ``p`` receives the ``p``-th
    range of every chunk. When a chunk does not divide evenly, the last
    ``partition_items % pipeline_parallel_size`` ranks each take one extra item.

    Args:
        num_items: Total number of items to distribute.
        pipeline_parallel_size: Number of ranks to distribute over.
        num_chunks: Number of equal chunks to cut the items into before
            splitting each chunk across the ranks.

    Returns:
        A list with one entry per rank. Entry ``p`` is a list of ``num_chunks``
        ``(start, end)`` half-open index ranges, one per chunk, in chunk order.

    Raises:
        AssertionError: If ``num_items`` is not divisible by ``num_chunks``.
        ValueError: If a chunk has fewer items than there are ranks, so some
            rank would receive an empty range.
        ZeroDivisionError: If ``num_chunks`` or ``pipeline_parallel_size`` is 0.
    """
    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``. AssertionError is kept so callers see the same exception
    # type as before.
    if num_items % num_chunks != 0:
        raise AssertionError(
            "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
        )

    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks

    # Every chunk has the same length, so its per-rank split is identical for
    # all chunks and can be computed once.
    chunk_size = partition_items // pipeline_parallel_size
    # Ranks with index >= ``left`` absorb the remainder, one extra item each.
    left = pipeline_parallel_size - partition_items % pipeline_parallel_size
    if chunk_size == 0:
        raise ValueError("Some nodes in Pipeline have no requests")

    for idx in range(num_chunks):
        base_idx = idx * partition_items
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + int(p >= left)
            parts[p].append((st, base_idx))

    # Internal sanity checks: the ranges must cover 0..num_items-1 exactly once.
    indexes = []
    for rank_ranges in parts:
        for s, e in rank_ranges:
            indexes.extend(range(s, e))
    assert len(indexes) == len(set(indexes)), indexes  # should have no duplicates
    assert set(indexes) == set(range(num_items)), (indexes, num_items)  # should have the same indexes as expected
    return parts
|