mirror of https://github.com/hpcaitech/ColossalAI
# Device

## Table of Contents

- [Device](#device)
- [Table of Contents](#table-of-contents)
- [Introduction](#introduction)
- [Design](#design)
- [Usage](#usage)

## Introduction

This module implements an abstraction of the device topology. It represents how the devices are arranged and manages the distributed information tied to that arrangement.

## Design

This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa). The device array can be represented as a 1D or 2D mesh. Support for 3D meshes is planned.
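
To make the 1D and 2D layouts concrete, the sketch below arranges a flat list of global ranks into a logical mesh using plain Python. `build_logical_mesh` is a hypothetical helper written only for illustration and is not part of this module; the row-major ordering is an assumption chosen to match the `[[0, 1], [2, 3]]` layout shown in the Usage section.

```python
from math import prod


def build_logical_mesh(ranks, mesh_shape):
    """Arrange a flat list of global ranks into a 1D or 2D row-major mesh.

    Hypothetical helper for illustration only; not part of the device module.
    """
    if len(mesh_shape) not in (1, 2):
        raise ValueError("only 1D and 2D meshes are sketched here")
    if prod(mesh_shape) != len(ranks):
        raise ValueError("mesh_shape must cover every rank exactly once")
    if len(mesh_shape) == 1:
        return list(ranks)
    rows, cols = mesh_shape
    # Slice the flat rank list into `rows` consecutive chunks of length `cols`.
    return [list(ranks[r * cols:(r + 1) * cols]) for r in range(rows)]


ranks = list(range(4))
mesh_1d = build_logical_mesh(ranks, [4])     # [0, 1, 2, 3]
mesh_2d = build_logical_mesh(ranks, [2, 2])  # [[0, 1], [2, 3]]
print(mesh_1d)
print(mesh_2d)
```

The same four ranks can therefore be viewed either as a single line of devices or as a 2x2 grid; only the logical arrangement changes.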

## Usage

- Create a device mesh

```python
import torch
# import path assumed from the ColossalAI package layout; adjust for your version
from colossalai.device.device_mesh import DeviceMesh

# the list of global ranks involved in the device mesh
# assume we have 4 GPUs whose global ranks are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```

- View the mesh

```python
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)

# view the logical mesh with global ranks
# expected output
# [
#   [0, 1],
#   [2, 3]
# ]
print(device_mesh.logical_mesh_id)

# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
```

- Initialize the process group

```python
# initialize the process groups
device_mesh.init_logical_process_group()

# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))

# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
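
The `[0, 2]` result follows from the mesh layout: a group along an axis contains the ranks reached by varying that axis's index while holding the other index fixed. A minimal plain-Python sketch of that rule, assuming the 2x2 row-major mesh used in this section; `ranks_along_axis` is a hypothetical helper for illustration, not this module's API:

```python
def ranks_along_axis(logical_mesh, axis, global_rank):
    """Return the ranks that share a group with global_rank along axis.

    Hypothetical helper for a 2D mesh given as a list of rows.
    axis=0 varies the row index (walks down a column of the mesh);
    axis=1 varies the column index (walks along a row of the mesh).
    """
    # Locate the (row, column) coordinate of the requested rank.
    for row_idx, row in enumerate(logical_mesh):
        if global_rank in row:
            col_idx = row.index(global_rank)
            break
    else:
        raise ValueError(f"rank {global_rank} is not in the mesh")
    if axis == 0:
        return [r[col_idx] for r in logical_mesh]
    if axis == 1:
        return list(logical_mesh[row_idx])
    raise ValueError("axis must be 0 or 1 for a 2D mesh")


logical_mesh = [[0, 1], [2, 3]]
print(ranks_along_axis(logical_mesh, axis=0, global_rank=0))  # [0, 2]
print(ranks_along_axis(logical_mesh, axis=1, global_rank=0))  # [0, 1]
print(ranks_along_axis(logical_mesh, axis=0, global_rank=3))  # [1, 3]
```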