# 🗄 Device
## 📚 Table of Contents
- [🗄 Device](#-device)
  - [📚 Table of Contents](#-table-of-contents)
  - [🔗 Introduction](#-introduction)
  - [📝 Design](#-design)
  - [🔨 Usage](#-usage)
## 🔗 Introduction
This module implements an abstraction of the device topology. It represents how the devices are arranged and manages the distributed information tied to that arrangement, such as the process groups formed along each axis of the mesh.
## 📝 Design
This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa). The device array can be represented as a 1D or 2D mesh. We plan to extend the device mesh to support 3D meshes in the future.
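
To make the 1D and 2D views concrete, here is a minimal pure-Python sketch of arranging a flat list of global ranks into a logical mesh. The helper `build_logical_mesh` is hypothetical and exists only for illustration; it is not part of this module, and the row-major layout is an assumption chosen to match the example in the Usage section.

```python
def build_logical_mesh(ranks, mesh_shape):
    """Arrange a flat list of global ranks into a row-major 1D or 2D mesh.

    Hypothetical helper for illustration only; not the DeviceMesh implementation.
    """
    if len(mesh_shape) == 1:
        (size,) = mesh_shape
        if size != len(ranks):
            raise ValueError("mesh_shape does not match the number of ranks")
        return list(ranks)
    if len(mesh_shape) == 2:
        rows, cols = mesh_shape
        if rows * cols != len(ranks):
            raise ValueError("mesh_shape does not match the number of ranks")
        # slice the flat list into `rows` consecutive chunks of length `cols`
        return [list(ranks[r * cols:(r + 1) * cols]) for r in range(rows)]
    raise ValueError("only 1D and 2D meshes are covered by this sketch")


ranks = list(range(4))
mesh_1d = build_logical_mesh(ranks, [4])
mesh_2d = build_logical_mesh(ranks, [2, 2])
print(mesh_1d)  # [0, 1, 2, 3]
print(mesh_2d)  # [[0, 1], [2, 3]]
```

The same four ranks can therefore be viewed either as a single line of devices or as a 2 x 2 grid, depending on the mesh shape.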
## 🔨 Usage
- Create a device mesh
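```python
import torch

# assumed import path for DeviceMesh; adjust if your ColossalAI version differs
from colossalai.device.device_mesh import DeviceMesh

# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```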
- View the mesh
```python
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)
# view the logical mesh with global ranks
# expected output
# [
#   [0, 1],
#   [2, 3]
# ]
print(device_mesh.logical_mesh_id)
# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
```
- Initialize the process group
```python
# initialize the process groups
device_mesh.init_logical_process_group()
# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))
# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
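
The grouping rule behind the `[0, 2]` result above can be sketched in plain Python: ranks that share a process group along an axis are those whose mesh coordinates differ only on that axis. The function `ranks_in_group` below is a hypothetical stand-in written for this explanation. It does not call into ColossalAI and does not create any real process group.

```python
def ranks_in_group(logical_mesh, axis, global_rank):
    """Return the ranks that share a group with `global_rank` along `axis`.

    Hypothetical illustration for a 2D mesh given as a list of rows.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2D mesh")
    for row in logical_mesh:
        if global_rank in row:
            col = row.index(global_rank)
            if axis == 0:
                # vary the row index, keep the column fixed
                return [other_row[col] for other_row in logical_mesh]
            # axis == 1: vary the column index, keep the row fixed
            return list(row)
    raise ValueError("global_rank is not in the mesh")


mesh = [[0, 1], [2, 3]]
group_axis0 = ranks_in_group(mesh, axis=0, global_rank=0)
group_axis1 = ranks_in_group(mesh, axis=1, global_rank=0)
print(group_axis0)  # [0, 2]
print(group_axis1)  # [0, 1]
```

On the 2 x 2 mesh, rank 0 therefore belongs to one group per axis: `[0, 2]` along axis 0 and `[0, 1]` along axis 1.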