mirror of https://github.com/hpcaitech/ColossalAI
Fix the bug where process groups were not being properly released.
parent 11009103be
commit d480cce472
@@ -1,3 +1,4 @@
+import gc
 import itertools
 from functools import reduce
 from operator import mul
@@ -44,6 +45,24 @@ class ProcessGroupMesh:
         self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
         self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
 
+    def __del__(self):
+        r"""
+        Destructor method for the ProcessGroupMesh class.
+
+        When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
+        cleaning up any process groups that were created during the lifetime of the object.
+
+        Note:
+            All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
+            when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
+            system resources.
+        """
+        for group in self._ranks_to_group.values():
+            dist.destroy_process_group(group)
+
+        # Manually clear all process groups to save memory
+        gc.collect()
+
     @property
     def shape(self) -> Tuple[int, ...]:
         return self._shape
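For reference, here is a minimal usage sketch (not part of this commit) showing how the new __del__ hook gets exercised. It assumes the public colossalai.cluster.ProcessGroupMesh API (including the get_group_along_axis method) and uses a throwaway single-process gloo setup purely for illustration:

    import os

    import torch.distributed as dist

    from colossalai.cluster import ProcessGroupMesh  # import path assumed

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    mesh = ProcessGroupMesh(1)            # trivial 1-d mesh; its size product must equal the world size
    _ = mesh.get_group_along_axis(0)      # forces creation of a subgroup cached in _ranks_to_group
    del mesh                              # __del__ now calls dist.destroy_process_group on each cached subgroup
    dist.destroy_process_group()          # tear down the default group as well

Without the __del__ hook, the subgroups created by the mesh would stay registered in torch.distributed's global state even after the mesh itself is gone.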
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
+import torch.distributed as dist
 
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ class LayoutConverter(metaclass=SingletonMeta):
         MAX_TRANSFORM_STEPS = 20
         total_steps = 0
         transform_path = []
-        comm_action_sequence = []
+        comm_action_sequence: List[CommSpec] = []
         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
 
         if spec_pairs in self.cached_solution:
-            return self.cached_solution[spec_pairs]
+            # Solution cache hit
+
+            def _group_alive_check(cached_comm_action_sequence):
+                r"""
+                Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method.
+                If not deleted, return True; otherwise, return False.
+
+                Args:
+                    cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.
+
+                Returns:
+                    bool: True if all process groups are still registered, False if at least one has been deleted.
+
+                Raises:
+                    RuntimeError: If there is an error while checking the status of a process group.
+                """
+
+                # Collect all process groups used in communication actions from the cached sequence
+                used_process_groups = [
+                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+                ]
+
+                # Check if each process group is still alive
+                for process_group in used_process_groups:
+                    try:
+                        dist.get_rank(process_group)
+                    except RuntimeError as e:
+                        # If the group is not registered, it means it has been deleted
+                        if str(e) == (
+                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+                        ):
+                            return False
+                        elif str(e) == "The given group does not exist":
+                            return False
+                        else:
+                            # Re-raise the exception if it's not related to group deletion
+                            raise e
+                # All process groups are alive
+                return True
+
+            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+            if _group_alive_check(cached_comm_action_sequence):
+                # If all process groups have not been deleted, the cache is valid
+                return cached_transform_path, cached_comm_action_sequence
+            else:
+                # If at least one process group has been deleted, the cache is invalid, so delete it
+                del self.cached_solution[spec_pairs]
 
         # We do nothing if the sharding spec is all the same.
         if source_spec.spec_diff(target_spec) == 0:
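The cache check above boils down to probing each cached ProcessGroup with dist.get_rank and treating a "not registered" / "does not exist" error as a stale entry. A standalone sketch of that probe (the helper name and the single-process gloo setup are illustrative, not part of the patch):

    import os

    import torch.distributed as dist

    def group_is_alive(group: dist.ProcessGroup) -> bool:
        # Return False if the group was torn down by dist.destroy_process_group.
        try:
            dist.get_rank(group)
            return True
        except (RuntimeError, ValueError):
            # The exact exception type and message differ across PyTorch versions,
            # which is why the patch above matches two different error strings.
            return False

    if __name__ == "__main__":
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group("gloo", rank=0, world_size=1)

        group = dist.new_group(ranks=[0])
        print(group_is_alive(group))      # True: the group is still registered
        dist.destroy_process_group(group)
        print(group_is_alive(group))      # False: the probe now fails internally
        dist.destroy_process_group()      # tear down the default group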