2021-12-29 15:32:10 +00:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from ..parallel_2d._operation import split_tensor_2d
|
|
|
|
from ..parallel_2p5d._operation import split_tensor_2p5d
|
2022-02-14 03:15:02 +00:00
|
|
|
from ..parallel_3d._operation import split_batch_3d
|
2021-12-29 15:32:10 +00:00
|
|
|
from ..utils import get_tensor_parallel_mode
|
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
# Dispatch table: tensor-parallel mode name (as returned by
# get_tensor_parallel_mode()) -> function that splits a batch for that mode.
# partition_batch only splits when the current mode is a key here; any mode
# absent from this table gets its input returned unchanged.
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
|
2021-12-29 15:32:10 +00:00
|
|
|
|
|
|
|
|
2022-02-14 03:15:02 +00:00
|
|
|
def partition_batch(input_) -> Tensor:
    """Split a batch for the currently active tensor-parallel mode.

    The mode is read once via ``get_tensor_parallel_mode()`` and used to pick a
    split function from ``_parallel_split_batch`` ('2d', '2.5d' or '3d').

    Args:
        input_: Either a single object handed straight to the split function
            (presumably a ``Tensor`` — TODO confirm against callers), or a
            ``dict`` whose values are each split independently.

    Returns:
        * If the mode has no registered split function: ``input_`` itself,
          untouched.
        * If ``input_`` is a ``dict``: a new ``dict`` with the same keys and
          each value replaced by its split result.
        * Otherwise: the split function's result for ``input_``.

        NOTE(review): the ``-> Tensor`` annotation does not cover the dict or
        pass-through cases.
    """
    mode = get_tensor_parallel_mode()

    # Guard clause: modes without a splitter leave the batch as-is.
    if mode not in _parallel_split_batch:
        return input_

    # Resolve the splitter once rather than per dict entry.
    splitter = _parallel_split_batch[mode]

    if not isinstance(input_, dict):
        return splitter(input_)

    return {key: splitter(value) for key, value in input_.items()}
|