mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			35 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			35 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
#!/usr/bin/env python
 | 
						|
# -*- encoding: utf-8 -*-
 | 
						|
 | 
						|
from internlm.utils.logger import get_logger
 | 
						|
 | 
						|
# Module-level logger obtained from the project's logging helper, keyed by this
# file's path. NOTE(review): it is currently not referenced by any code in this module.
logger = get_logger(__file__)
 | 
						|
 | 
						|
 | 
						|
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    """Split ``num_items`` contiguous items uniformly across pipeline stages.

    The items are first cut into ``num_chunks`` equal, contiguous chunks. Each
    chunk is then split into ``pipeline_parallel_size`` contiguous ranges, one
    per stage. When a chunk does not divide evenly, the last
    ``chunk_len % pipeline_parallel_size`` stages each receive one extra item.

    Args:
        num_items: Total number of items (e.g. layers) to partition.
        pipeline_parallel_size: Number of pipeline stages.
        num_chunks: Number of chunks the items are cut into before being
            distributed; every stage receives one range from each chunk.

    Returns:
        A list with one entry per stage. Entry ``p`` is a list of
        ``num_chunks`` half-open ``(start, end)`` index ranges, in chunk order.

    Raises:
        ValueError: If ``pipeline_parallel_size`` or ``num_chunks`` is not
            positive, if ``num_items`` is not divisible by ``num_chunks``, or
            if a chunk has fewer items than there are stages (some stage would
            receive nothing).

    Example:
        >>> partition_uniform(8, 2, 2)
        [[(0, 2), (4, 6)], [(2, 4), (6, 8)]]
    """
    if pipeline_parallel_size <= 0:
        raise ValueError(f"pipeline_parallel_size must be positive, got {pipeline_parallel_size}")
    if num_chunks <= 0:
        raise ValueError(f"num_chunks must be positive, got {num_chunks}")
    # Raised rather than asserted so the check survives ``python -O``; without it
    # the trailing ``num_items % num_chunks`` items would be silently dropped.
    if num_items % num_chunks != 0:
        raise ValueError(
            "Layer length should be divisible by the number of chunks, otherwise parameter method is recommended"
        )

    chunk_len = num_items // num_chunks
    # Every stage gets ``base_size`` items per chunk; the ``remainder`` leftover
    # items go one each to the last stages, starting at index ``first_larger``.
    base_size, remainder = divmod(chunk_len, pipeline_parallel_size)
    if base_size < 1:
        raise ValueError("Some nodes in Pipeline have no requests")
    first_larger = pipeline_parallel_size - remainder

    parts = [[] for _ in range(pipeline_parallel_size)]
    for chunk_idx in range(num_chunks):
        start = chunk_idx * chunk_len
        for stage in range(pipeline_parallel_size):
            end = start + base_size + (stage >= first_larger)
            parts[stage].append((start, end))
            start = end

    # Internal consistency check: the ranges must cover 0..num_items-1 exactly once.
    indexes = []
    for stage_ranges in parts:
        for range_start, range_end in stage_ranges:
            indexes.extend(range(range_start, range_end))
    assert len(indexes) == len(set(indexes)), indexes  # should have no duplicates
    assert set(indexes) == set(range(num_items)), (indexes, num_items)  # should have the same indexes as expected
    return parts