Fix the bug where process groups were not being properly released.

Branch: pull/4940/head
Author: littsk
Date: 2023-10-23 12:24:02 +08:00
Parent: 11009103be
Commit: d480cce472
2 changed files with 69 additions and 2 deletions
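For context, the leak this commit addresses comes from torch.distributed keeping every process group it creates in process-global state: dropping the Python references held by a ProcessGroupMesh does not release its groups. A minimal single-process sketch of that underlying behavior (the gloo backend, address, and world size are illustrative assumptions, not part of the patch):

import torch.distributed as dist

# Assumed single-process setup, purely for illustration.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

pg = dist.new_group(ranks=[0])  # the group is registered in torch.distributed's global state
pg = None                       # dropping the Python reference does NOT unregister the group

# The group keeps holding backend resources until dist.destroy_process_group is called on it,
# which is exactly what the new ProcessGroupMesh.__del__ below does for each group it created.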


@@ -1,3 +1,4 @@
+import gc
 import itertools
 from functools import reduce
 from operator import mul
@@ -44,6 +45,24 @@ class ProcessGroupMesh:
         self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
         self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

+    def __del__(self):
+        r"""
+        Destructor of the ProcessGroupMesh class.
+
+        Called when the ProcessGroupMesh object is deleted or goes out of scope. It is responsible for
+        cleaning up the process groups that were created during the lifetime of the object.
+
+        Note:
+            PyTorch keeps every process group in process-global state, so the groups are not destroyed
+            automatically when the ProcessGroupMesh's lifetime ends. This method destroys them explicitly
+            to release the underlying communication resources.
+        """
+        for group in self._ranks_to_group.values():
+            dist.destroy_process_group(group)
+
+        # The groups are released above; trigger garbage collection to reclaim their memory promptly.
+        gc.collect()
+
     @property
     def shape(self) -> Tuple[int, ...]:
         return self._shape
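As a usage illustration, here is a sketch of the intended effect of the new destructor. The 2x2 mesh size, the gloo backend, the import path, and the get_group_along_axis call are assumptions for the example; it would need a 4-rank launch (e.g. via torchrun):

# Hypothetical sketch: run with 4 ranks, e.g. `torchrun --nproc_per_node=4 demo.py`.
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh  # assumed import path

dist.init_process_group(backend="gloo")

mesh = ProcessGroupMesh(2, 2)            # a 2 x 2 logical mesh over 4 ranks
tp_group = mesh.get_group_along_axis(1)  # creating a group caches it in mesh._ranks_to_group

del mesh  # __del__ now destroys every cached group and runs gc.collect(), so the groups
          # no longer linger in torch.distributed's global state

# Caveat: tp_group still references a destroyed group at this point; the LayoutConverter
# change in the second file guards its solution cache against exactly this situation.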


@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Tuple

 import torch
+import torch.distributed as dist

 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ class LayoutConverter(metaclass=SingletonMeta):
         MAX_TRANSFORM_STEPS = 20
         total_steps = 0
         transform_path = []
-        comm_action_sequence = []
+        comm_action_sequence: List[CommSpec] = []
         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))

         if spec_pairs in self.cached_solution:
-            return self.cached_solution[spec_pairs]
+            # Solution cache hit: reuse the stored result only if its process groups are still alive.
+
+            def _group_alive_check(cached_comm_action_sequence: List[CommSpec]) -> bool:
+                r"""
+                Check whether the process groups required by a cached sharding solution have already been
+                destroyed via torch.distributed.destroy_process_group.
+
+                Args:
+                    cached_comm_action_sequence (List[CommSpec]): The cached sequence of communication actions.
+
+                Returns:
+                    bool: True if every process group is still registered, False if at least one has been destroyed.
+
+                Raises:
+                    RuntimeError: If checking a process group fails for a reason other than the group having been destroyed.
+                """
+                # Collect every process group referenced by the cached communication actions.
+                used_process_groups = [
+                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+                ]
+
+                # Probe each group; dist.get_rank raises a RuntimeError once a group has been destroyed.
+                for process_group in used_process_groups:
+                    try:
+                        dist.get_rank(process_group)
+                    except RuntimeError as e:
+                        # These messages mean the group is no longer registered, i.e. it has been destroyed.
+                        if str(e) == (
+                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+                        ):
+                            return False
+                        elif str(e) == "The given group does not exist":
+                            return False
+                        else:
+                            # Re-raise exceptions unrelated to group destruction.
+                            raise e
+
+                # All process groups are still alive.
+                return True
+
+            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+            if _group_alive_check(cached_comm_action_sequence):
+                # Every process group is still alive, so the cached solution is valid.
+                return cached_transform_path, cached_comm_action_sequence
+            else:
+                # At least one process group has been destroyed, so the cached solution is stale; drop it.
+                del self.cached_solution[spec_pairs]
+
         # We do nothing if the sharding spec is all the same.
         if source_spec.spec_diff(target_spec) == 0:
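The cache validation above boils down to probing each cached group with dist.get_rank and treating a "not registered" RuntimeError as "destroyed". A standalone sketch of that probe, runnable in a single process (the helper name probe_group_alive, the backend, and the address are assumptions; newer torch releases may raise a different exception type or message, which is why the patch matches the exact strings of the torch version it targets):

import torch.distributed as dist

def probe_group_alive(group) -> bool:
    """Return False if the group was already torn down by destroy_process_group."""
    try:
        dist.get_rank(group)  # raises once the group is unregistered
        return True
    except (RuntimeError, ValueError):
        return False

if __name__ == "__main__":
    # Assumed single-process setup, purely for illustration.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
    pg = dist.new_group(ranks=[0])
    assert probe_group_alive(pg)
    dist.destroy_process_group(pg)
    assert not probe_group_alive(pg)
    dist.destroy_process_group()  # tear down the default group as well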