mirror of https://github.com/hpcaitech/ColossalAI
[device] update flatten device mesh usage (#2079)
parent
a7adad9ccb
commit
677e1e20d4
|
@ -24,6 +24,7 @@ class DeviceMesh:
|
|||
during initializing the DeviceMesh instance if the init_process_group set to True.
|
||||
Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
|
||||
(default: False)
|
||||
need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -50,7 +51,7 @@ class DeviceMesh:
|
|||
self.need_flatten = need_flatten
|
||||
if self.init_process_group:
|
||||
self.process_groups_dict = self.create_process_groups_for_logical_mesh()
|
||||
if self.need_flatten:
|
||||
if self.need_flatten and self._logical_mesh_id.dim() > 1:
|
||||
self.flatten_device_mesh = self.flatten()
|
||||
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
|
||||
self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
|
||||
|
|
Loading…
Reference in New Issue