mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
22 lines | 706 B
from typing import Tuple

import torch


def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))

    # Determine number of padding elements.
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    assert num_to_pad >= 0, num_to_pad

    # Copy this rank's chunk into a zero-padded buffer so every rank
    # holds a shard of identical size.
    shard = torch.zeros_like(chunks[0])
    length = chunks[rank].size(0)
    shard_temp = shard[:length]
    shard_temp.copy_(chunks[rank])

    return shard, num_to_pad
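
For context, a minimal usage sketch (not part of this file): the zero-padded shard returned by get_shard can be re-assembled across ranks with torch.distributed.all_gather, and the padding dropped by slicing back to the original element count. It assumes the default process group is already initialized; the helper name gather_full_tensor is illustrative, not ColossalAI API.

import torch
import torch.distributed as dist


def gather_full_tensor(shard: torch.Tensor, orig_shape: torch.Size) -> torch.Tensor:
    """Illustrative inverse of get_shard: all-gather the padded shards and trim the padding."""
    world_size = dist.get_world_size()
    # get_shard pads every rank's chunk to the same size, so equal-sized buffers suffice.
    buffers = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard)
    # torch.chunk places any shortfall at the tail, so the concatenated shards are the
    # flattened tensor followed by zero padding; slice back to the original numel.
    flat = torch.cat(buffers)
    return flat[:orig_shape.numel()].view(orig_shape)

On each rank this would pair as: shard, _ = get_shard(tensor, dist.get_rank(), dist.get_world_size()), followed by full = gather_full_tensor(shard, tensor.shape).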