# 🗄 Device

## 📚 Table of Contents

- [🔗 Introduction](#-introduction)
- [📝 Design](#-design)
- [🔨 Usage](#-usage)

## 🔗 Introduction

This module implements the device topology abstraction. It is used to represent the device topology and to manage the distributed information related to the network.

## 📝 Design

This module is inspired by the DeviceMesh in the Alpa project. The device array can be represented as a 1D or 2D mesh, as sketched below, and we will extend the device mesh to support 3D meshes in the future.
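
As a minimal sketch of the two layouts (assuming `DeviceMesh` is importable from `colossalai.device.device_mesh`, consistent with this README's location, and that constructing a mesh does not by itself create process groups), the same four devices can be arranged either way:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh

# the same four global ranks can be arranged into different logical meshes
physical_mesh_id = torch.arange(4)

# 1D mesh: a single axis holding all 4 devices
# logical mesh: [0, 1, 2, 3]
mesh_1d = DeviceMesh(physical_mesh_id, [4])
print(mesh_1d.shape)  # [4]

# 2D mesh: the same devices arranged as a 2 x 2 grid
# logical mesh: [[0, 1], [2, 3]]
mesh_2d = DeviceMesh(physical_mesh_id, [2, 2])
print(mesh_2d.shape)  # [2, 2]
```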

## 🔨 Usage

- Create a device mesh

```python
import torch

from colossalai.device.device_mesh import DeviceMesh

# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```
- View the mesh

```python
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)

# view the logical mesh with global ranks
# expected output
# [
#   [0, 1],
#   [2, 3]
# ]
print(device_mesh.logical_mesh_id)

# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
```

- Initialize the process group

```python
# initialize the process groups implied by the logical mesh
device_mesh.init_logical_process_group()

# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))

# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
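
Putting the steps together, below is a hedged end-to-end sketch. It assumes the script is launched with `torchrun --nproc_per_node 4` (the script name `demo.py` is illustrative), that a global process group must be initialized via `torch.distributed.init_process_group` before `init_logical_process_group` is called, and that the NCCL backend is available. The `axis=1` call and its expected output are inferred from the 2 x 2 logical mesh shown above, not taken from the original text.

```python
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh

# assumed launch command: torchrun --nproc_per_node 4 demo.py
# a global process group is assumed to be required before the
# logical process groups can be created
dist.init_process_group(backend="nccl")

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, [2, 2])
device_mesh.init_logical_process_group()

# with the logical mesh [[0, 1], [2, 3]]:
# axis 0 groups ranks down a column -> global rank 0 belongs to [0, 2]
# axis 1 groups ranks along a row   -> global rank 0 belongs to [0, 1]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))  # [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=1, global_rank=0))  # [0, 1]
```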