mirror of https://github.com/hpcaitech/ColossalAI
Device
Introduction
This module implements an abstraction of the device topology. It represents how the devices are arranged and manages the distributed information tied to that arrangement, such as the process groups along each mesh axis.
Design
This module is inspired by the DeviceMesh in the Alpa project. The device array can be represented as a 1D or 2D mesh. Support for 3D meshes is planned for the future.
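To make the 1D and 2D representations concrete, here is a minimal pure-Python sketch of how a flat list of global ranks could be laid out as a logical mesh. It does not use ColossalAI: `build_logical_mesh` is a hypothetical helper introduced only for illustration, and the row-major ordering is an assumption chosen to match the `[[0, 1], [2, 3]]` layout shown under Usage.

```python
from math import prod


def build_logical_mesh(ranks, mesh_shape):
    """Hypothetical helper (not part of ColossalAI): arrange a flat
    sequence of global ranks as a 1D or 2D mesh in row-major order."""
    ranks = list(ranks)
    if len(mesh_shape) not in (1, 2):
        raise ValueError("this sketch only covers 1D and 2D meshes")
    if prod(mesh_shape) != len(ranks):
        raise ValueError("mesh_shape does not match the number of ranks")
    if len(mesh_shape) == 1:
        # A 1D mesh is just the list of ranks.
        return ranks
    rows, cols = mesh_shape
    # A 2D mesh is a list of rows, each holding `cols` consecutive ranks.
    return [ranks[r * cols:(r + 1) * cols] for r in range(rows)]


mesh_1d = build_logical_mesh(range(4), [4])
mesh_2d = build_logical_mesh(range(4), [2, 2])
print(mesh_1d)  # [0, 1, 2, 3]
print(mesh_2d)  # [[0, 1], [2, 3]]
```

The same four ranks yield either mesh; only the shape changes how they are grouped along each axis.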
Usage
- Create a device mesh
import torch
from colossalai.device.device_mesh import DeviceMesh

# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
# arrange the 4 devices as a 2 x 2 logical mesh
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- View the mesh
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)
# view the logical mesh with global ranks
# expected output
# [
# [0, 1],
# [2, 3]
# ]
print(device_mesh.logical_mesh_id)
# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
- Initialize the process group
# initialize the process groups
device_mesh.init_logical_process_group()
# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))
# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
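The rank lists above follow directly from the logical mesh: a process group along an axis contains every rank reached by varying the index on that axis while keeping the other index fixed. The sketch below reproduces that rule in plain Python so it can be checked without a distributed setup. `ranks_in_group` is a hypothetical helper, not a ColossalAI function, and it assumes a 2D mesh stored as a list of rows.

```python
def ranks_in_group(logical_mesh, axis, global_rank):
    """Hypothetical helper (not part of ColossalAI): return the ranks that
    share a process group with `global_rank` along `axis` of a 2D mesh."""
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2D mesh")
    for row in logical_mesh:
        if global_rank in row:
            col = row.index(global_rank)
            if axis == 0:
                # Vary the row index, keep the column fixed.
                return [r[col] for r in logical_mesh]
            # axis == 1: vary the column index, keep the row fixed.
            return list(row)
    raise ValueError(f"rank {global_rank} is not in the mesh")


logical_mesh = [[0, 1], [2, 3]]
print(ranks_in_group(logical_mesh, axis=0, global_rank=0))  # [0, 2]
print(ranks_in_group(logical_mesh, axis=1, global_rank=0))  # [0, 1]
```

For the 2 x 2 mesh, rank 0 therefore belongs to the group `[0, 2]` along axis 0 and to the group `[0, 1]` along axis 1.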