Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training.
It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor.
## đ Design
Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension.
Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below:
```text
[1, 2, 3, 4 ]
A = [4, 5, 6, 7 ]
[8, 9, 10, 11]
[12, 13, 14, 15]
```
`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology.
- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.)