diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index b77fe5eef..7596a100b 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -24,6 +24,7 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) + need_flatten(bool, optional): if set to True, flatten_device_mesh is initialized while the DeviceMesh instance is being initialized, provided the logical mesh id has more than one dimension. """ def __init__(self, @@ -50,7 +51,7 @@ class DeviceMesh: self.need_flatten = need_flatten if self.init_process_group: self.process_groups_dict = self.create_process_groups_for_logical_mesh() - if self.need_flatten: + if self.need_flatten and self._logical_mesh_id.dim() > 1: self.flatten_device_mesh = self.flatten() # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,