mirror of https://github.com/hpcaitech/ColossalAI
Fix the bug where process groups were not being properly released.
parent 11009103be
commit d480cce472
@@ -1,3 +1,4 @@
+import gc
 import itertools
 from functools import reduce
 from operator import mul
@@ -44,6 +45,24 @@ class ProcessGroupMesh:
         self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
         self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

+    def __del__(self):
+        r"""
+        Destructor method for the ProcessGroupMesh class.
+
+        When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
+        cleaning up any process groups that were created during the lifetime of the object.
+
+        Note:
+            All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
+            when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
+            system resources.
+        """
+        for group in self._ranks_to_group.values():
+            dist.destroy_process_group(group)
+
+        # Manually clear all process groups to save memory
+        gc.collect()
+
     @property
     def shape(self) -> Tuple[int, ...]:
         return self._shape
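To illustrate the cleanup pattern this hunk introduces, here is a minimal, self-contained sketch under stated assumptions: GroupRegistry is a hypothetical stand-in for ProcessGroupMesh's group bookkeeping (not the actual class), and torch.distributed is assumed to be initialized already, e.g. via torchrun.

import gc
from typing import Dict, Tuple

import torch.distributed as dist


class GroupRegistry:
    """Toy stand-in for ProcessGroupMesh's group bookkeeping (illustrative only)."""

    def __init__(self) -> None:
        self._ranks_to_group: Dict[Tuple[int, ...], dist.ProcessGroup] = {}

    def get_group(self, ranks) -> dist.ProcessGroup:
        # Create and cache a group for this rank set on first use.
        key = tuple(sorted(ranks))
        if key not in self._ranks_to_group:
            self._ranks_to_group[key] = dist.new_group(ranks=list(key))
        return self._ranks_to_group[key]

    def __del__(self):
        # PyTorch tracks process groups in global state, so dropping the last
        # Python reference to this object does not free them. Destroy each
        # cached group explicitly, then trigger a collection pass.
        for group in self._ranks_to_group.values():
            dist.destroy_process_group(group)
        gc.collect()

The explicit destroy_process_group calls are what release the underlying communicator resources; the gc.collect() pass only helps reclaim the Python-side objects sooner.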
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Tuple

 import torch
+import torch.distributed as dist

 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ class LayoutConverter(metaclass=SingletonMeta):
         MAX_TRANSFORM_STEPS = 20
         total_steps = 0
         transform_path = []
-        comm_action_sequence = []
+        comm_action_sequence: List[CommSpec] = []
         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))

         if spec_pairs in self.cached_solution:
-            return self.cached_solution[spec_pairs]
+            # Solution cache hit
+
+            def _group_alive_check(cached_comm_action_sequence):
+                r"""
+                Check if the process groups required for sharding have been deleted by the torch.distributed.destroy_process_group method.
+                If none have been deleted, return True; otherwise, return False.
+
+                Args:
+                    cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.
+
+                Returns:
+                    bool: True if all process groups are still registered, False if at least one has been deleted.
+
+                Raises:
+                    RuntimeError: If there is an error while checking the status of a process group.
+                """
+                # Collect all process groups used in communication actions from the cached sequence
+                used_process_groups = [
+                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+                ]
+
+                # Check if each process group is still alive
+                for process_group in used_process_groups:
+                    try:
+                        dist.get_rank(process_group)
+                    except RuntimeError as e:
+                        # If the group is not registered, it has been deleted
+                        if str(e) == (
+                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+                        ):
+                            return False
+                        elif str(e) == "The given group does not exist":
+                            return False
+                        else:
+                            # Re-raise the exception if it's not related to group deletion
+                            raise e
+                # All process groups are alive
+                return True
+
+            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+            if _group_alive_check(cached_comm_action_sequence):
+                # If no process group has been deleted, the cached solution is still valid
+                return cached_transform_path, cached_comm_action_sequence
+            else:
+                # If at least one process group has been deleted, the cache is invalid, so delete it
+                del self.cached_solution[spec_pairs]

         # We do nothing if the sharding spec is all the same.
         if source_spec.spec_diff(target_spec) == 0:
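The probe above relies on the exact RuntimeError messages torch.distributed raises for destroyed or unregistered groups. As a rough standalone sketch of the same validate-then-invalidate cache pattern (the cache layout and the groups_for helper are hypothetical illustrations, not LayoutConverter's actual internals):

import torch.distributed as dist


def _is_group_alive(group) -> bool:
    # Probe the group; torch.distributed raises RuntimeError for groups that
    # were destroyed (or never registered) on this rank.
    try:
        dist.get_rank(group)
        return True
    except RuntimeError as e:
        if "is not registered" in str(e) or str(e) == "The given group does not exist":
            return False
        raise  # unrelated failure, propagate it


def lookup_cached_solution(cache, key, groups_for):
    # cache maps a spec pair to (transform_path, comm_action_sequence);
    # groups_for(spec) yields the process groups a cached action would use.
    if key in cache:
        path, actions = cache[key]
        if all(_is_group_alive(pg) for spec in actions for pg in groups_for(spec)):
            return path, actions  # every group is still alive: cache entry is valid
        del cache[key]  # at least one group was destroyed: drop the stale entry
    return None

Matching on error strings is brittle across PyTorch versions, which is why the real check also re-raises anything it does not recognize.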