From d480cce472eb344345e5f3b78c9709ee7f5b8985 Mon Sep 17 00:00:00 2001
From: littsk <1214689160@qq.com>
Date: Mon, 23 Oct 2023 12:24:02 +0800
Subject: [PATCH] Fix the bug where process groups were not being properly
 released.

---
 colossalai/cluster/process_group_mesh.py | 19 +++++++
 .../tensor/d_tensor/layout_converter.py  | 52 ++++++++++++++++++-
 2 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 3885bc962..eb4532194 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -1,3 +1,4 @@
+import gc
 import itertools
 from functools import reduce
 from operator import mul
@@ -44,6 +45,24 @@ class ProcessGroupMesh:
         self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
         self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
 
+    def __del__(self):
+        r"""
+        Destructor method for the ProcessGroupMesh class.
+
+        When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
+        cleaning up any process groups that were created during the lifetime of the object.
+
+        Note:
+            All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
+            when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
+            system resources.
+        """
+        for group in self._ranks_to_group.values():
+            dist.destroy_process_group(group)
+
+        # Manually clear all process groups to save memory
+        gc.collect()
+
     @property
     def shape(self) -> Tuple[int, ...]:
         return self._shape
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index e031e0472..abe4a86d8 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
+import torch.distributed as dist
 
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ class LayoutConverter(metaclass=SingletonMeta):
         MAX_TRANSFORM_STEPS = 20
         total_steps = 0
         transform_path = []
-        comm_action_sequence = []
+        comm_action_sequence: List[CommSpec] = []
         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
 
         if spec_pairs in self.cached_solution:
-            return self.cached_solution[spec_pairs]
+            # Solution Cache hit
+
+            def _group_alive_check(cached_comm_action_sequence):
+                r"""
+                Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method.
+                If not deleted, return True; otherwise, return False.
+
+                Args:
+                    cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.
+
+                Returns:
+                    bool: True if all process groups are still registered, False if at least one has been deleted.
+
+                Raises:
+                    RuntimeError: If there is an error while checking the status of a process group.
+                """
+
+                # Collect all process groups used in communication actions from the cached sequence
+                used_process_groups = [
+                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+                ]
+
+                # Check if each process group is still alive
+                for process_group in used_process_groups:
+                    try:
+                        dist.get_rank(process_group)
+                    except RuntimeError as e:
+                        # If the group is not registered, it means it has been deleted
+                        if str(e) == (
+                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+                        ):
+                            return False
+                        elif str(e) == "The given group does not exist":
+                            return False
+                        else:
+                            # Re-raise the exception if it's not related to group deletion
+                            raise e
+                # All process groups are alive
+                return True
+
+            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+            if _group_alive_check(cached_comm_action_sequence):
+                # If all process groups have not been deleted, the cache is valid
+                return cached_transform_path, cached_comm_action_sequence
+            else:
+                # If at least one process group has been deleted, the cache is invalid, so delete it
+                del self.cached_solution[spec_pairs]
 
         # We do nothing if the sharding spec is all the same.
         if source_spec.spec_diff(target_spec) == 0:
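
A note on the mechanism (not part of the patch): _group_alive_check relies on torch.distributed.get_rank raising an error when it is handed a group that has already been passed to torch.distributed.destroy_process_group, which is exactly what the new ProcessGroupMesh.__del__ now does for every group it created. Below is a minimal single-process sketch of that signal; the gloo backend and the MASTER_ADDR/MASTER_PORT values are illustrative assumptions, and the exact exception class and message vary across PyTorch releases, which is why the patch matches both known RuntimeError messages.

    # Illustrative sketch only: the liveness signal that _group_alive_check builds on.
    import os

    import torch.distributed as dist

    # Assumed single-process rendezvous settings, purely for this sketch.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    group = dist.new_group(ranks=[0])
    assert dist.get_rank(group) == 0  # the group is registered and alive

    # Mirrors what ProcessGroupMesh.__del__ now does for each group it owns.
    dist.destroy_process_group(group)

    try:
        dist.get_rank(group)  # querying a destroyed group fails
    except (RuntimeError, ValueError) as e:
        # The patch matches RuntimeError messages; newer PyTorch releases may raise ValueError.
        print(f"group no longer registered: {e}")

    dist.destroy_process_group()  # tear down the default group as well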