Browse Source

fix-test (#5210)

fix-test

fix-test
pull/4976/merge
flybird11111 11 months ago committed by GitHub
parent
commit
365671be10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      colossalai/booster/plugin/gemini_plugin.py
  2. 4
      colossalai/booster/plugin/hybrid_parallel_plugin.py
  3. 2
      colossalai/cluster/process_group_mesh.py

4
colossalai/booster/plugin/gemini_plugin.py

@ -437,6 +437,10 @@ class GeminiPlugin(DPPluginBase):
enable_sequence_overlap=self.enable_sequence_overlap,
)
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
def support_no_sync(self) -> bool:
return False

4
colossalai/booster/plugin/hybrid_parallel_plugin.py

@ -1054,6 +1054,10 @@ class HybridParallelPlugin(PipelinePluginBase):
self.max_norm = max_norm
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
@property
def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1

2
colossalai/cluster/process_group_mesh.py

@ -45,7 +45,7 @@ class ProcessGroupMesh:
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def __del__(self):
def destroy_mesh_process_groups(self):
r"""
Destructor method for the ProcessGroupMesh class.

Loading…
Cancel
Save