mirror of https://github.com/hpcaitech/ColossalAI
[API] implement device mesh manager (#3221)
* [API] implement device mesh manager * polishpull/3194/head
parent
052b03e83f
commit
4d5d8f98a4
|
@ -1,36 +1,117 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceMeshInfo:
|
||||||
|
'''
|
||||||
|
This class is used to store the information used to initialize the device mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7].
|
||||||
|
mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2].
|
||||||
|
'''
|
||||||
|
physical_ids: List[int]
|
||||||
|
mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.mesh_shape is not None:
|
||||||
|
world_size = len(self.physical_ids)
|
||||||
|
mesh_shape_numel = torch.Size(self.mesh_shape).numel()
|
||||||
|
assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}'
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_device_mesh(device_mesh_info: DeviceMeshInfo):
|
||||||
|
'''
|
||||||
|
This method is used to initialize the device mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh.
|
||||||
|
'''
|
||||||
|
# parse the device mesh info
|
||||||
|
physical_devices = device_mesh_info.physical_ids
|
||||||
|
physical_mesh = torch.tensor(physical_devices)
|
||||||
|
logical_mesh_shape = device_mesh_info.mesh_shape
|
||||||
|
|
||||||
|
if logical_mesh_shape is None:
|
||||||
|
ab_profiler = AlphaBetaProfiler(physical_devices)
|
||||||
|
# search for the best logical mesh shape
|
||||||
|
logical_mesh_id = ab_profiler.search_best_logical_mesh()
|
||||||
|
logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
|
||||||
|
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, logical_mesh_id=logical_mesh_id, init_process_group=True)
|
||||||
|
return device_mesh
|
||||||
|
|
||||||
|
|
||||||
class DeviceMeshManager:
|
class DeviceMeshManager:
|
||||||
"""
|
"""
|
||||||
Device mesh manager is responsible for creating and managing device meshes.
|
Device mesh manager is responsible for creating and managing device meshes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.device_mesh_store = dict()
|
self.device_mesh_store: Dict[str, DeviceMesh] = dict()
|
||||||
|
|
||||||
def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
|
def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMesh:
|
||||||
"""
|
"""
|
||||||
Create a device mesh and store it in the manager.
|
Create a device mesh and store it in the manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): name of the device mesh
|
name (str): name of the device mesh
|
||||||
*args: args for DeviceMesh
|
device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh
|
||||||
**kwargs: kwargs for DeviceMesh
|
|
||||||
"""
|
"""
|
||||||
# TODO(Yuliang): replace *args, **kwargs with explicit arguments
|
|
||||||
if name not in self.device_mesh_store:
|
if name not in self.device_mesh_store:
|
||||||
device_mesh = DeviceMesh(*args, **kwargs)
|
device_mesh = initialize_device_mesh(device_mesh_info)
|
||||||
self.device_mesh_store[name] = device_mesh
|
self.device_mesh_store[name] = device_mesh
|
||||||
return device_mesh
|
return device_mesh
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Device mesh {name} already exists.')
|
raise ValueError(f'Device mesh {name} already exists.')
|
||||||
|
|
||||||
def get(self, name: str) -> DeviceMesh:
|
def get(self, name: str) -> DeviceMesh:
|
||||||
pass
|
"""
|
||||||
|
Get a device mesh by name.
|
||||||
|
|
||||||
def destroy(self):
|
Args:
|
||||||
pass
|
name (str): name of the device mesh
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DeviceMesh: the device mesh
|
||||||
|
"""
|
||||||
|
if name in self.device_mesh_store:
|
||||||
|
return self.device_mesh_store[name]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Device mesh {name} does not exist.')
|
||||||
|
|
||||||
|
def destroy(self, name: str) -> None:
|
||||||
|
"""
|
||||||
|
Destroy a device mesh by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): name of the device mesh
|
||||||
|
"""
|
||||||
|
if name in self.device_mesh_store:
|
||||||
|
for pgs in self.device_mesh_store[name].process_groups_dict.values():
|
||||||
|
for pg in pgs:
|
||||||
|
dist.destroy_process_group(pg)
|
||||||
|
del self.device_mesh_store[name]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Device mesh {name} does not exist.')
|
||||||
|
|
||||||
def destroy_all(self):
|
def destroy_all(self):
|
||||||
pass
|
"""
|
||||||
|
Destroy all device meshes.
|
||||||
|
"""
|
||||||
|
for name in self.device_mesh_store:
|
||||||
|
for pgs in self.device_mesh_store[name].process_groups_dict.values():
|
||||||
|
for pg in pgs:
|
||||||
|
dist.destroy_process_group(pg)
|
||||||
|
|
||||||
|
self.device_mesh_store.clear()
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx.tracer import ColoTracer
|
||||||
|
from colossalai.initialize import launch
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
|
||||||
|
|
||||||
|
def check_device_mesh_manager(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
device_mesh_manager = DeviceMeshManager()
|
||||||
|
device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],)
|
||||||
|
device_mesh_auto = device_mesh_manager.create_device_mesh('0', device_mesh_info_auto)
|
||||||
|
assert device_mesh_auto.shape == (2, 2)
|
||||||
|
assert device_mesh_auto._logical_mesh_id.tolist() == [[0, 1], [2, 3]]
|
||||||
|
|
||||||
|
device_mesh_info_with_shape = DeviceMeshInfo(
|
||||||
|
physical_ids=[0, 1, 2, 3],
|
||||||
|
mesh_shape=(2, 2),
|
||||||
|
)
|
||||||
|
device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape)
|
||||||
|
|
||||||
|
assert device_mesh_with_shape.shape == (2, 2)
|
||||||
|
assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_device_mesh_manager():
|
||||||
|
world_size = 4
|
||||||
|
run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_device_mesh_manager()
|
Loading…
Reference in New Issue