From 2238758c2e850a206698a0d65e19c1c455297e61 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 25 Apr 2022 13:42:31 +0800
Subject: [PATCH] [usability] improved error messages in the context module
 (#856)

---
 colossalai/context/parallel_context.py          | 26 ++++++++++---------
 .../initializer_2p5d.py                         |  6 -----
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 6102e701a..afa306065 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -4,6 +4,7 @@
 import random
 import socket
 from collections import Counter
+from threading import local
 from typing import Union
 
 import numpy as np
@@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
 
     @staticmethod
     def _check_parallel_mode(parallel_mode: ParallelMode):
-        assert isinstance(parallel_mode, ParallelMode)
+        assert isinstance(parallel_mode, ParallelMode), \
+            f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'
 
     def get_global_rank(self):
         """Returns the global rank of the current device.
@@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._local_ranks[parallel_mode]
 
-    def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
+    def _add_local_rank(self, parallel_mode: ParallelMode, rank: int):
         """Adds the local rank of the current device for `parallel_mode` to the context.
 
         Args:
@@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._world_sizes[parallel_mode]
 
-    def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
+    def _add_world_size(self, parallel_mode: ParallelMode, world_size: int):
         """Adds world size for `parallel_mode`.
 
         Args:
-            parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+            parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group
             world_size (int): The world size to be added
 
         Raises:
@@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._groups[parallel_mode]
 
-    def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+    def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
         """Adds the group of the current device for `parallel_mode`.
 
         Args:
@@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._cpu_groups[parallel_mode]
 
-    def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+    def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
         """Adds the Gloo group of the current device for `parallel_mode`.
 
         :param parallel_mode: The chosen parallel mode
@@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._ranks_in_group[parallel_mode]
 
-    def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
+    def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
         """Adds the ranks of the current device for `parallel_mode` in the group.
 
         Args:
@@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
         self.add_global_rank(ParallelMode.GLOBAL, rank)
 
     def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
-        self.add_local_rank(mode, local_rank)
-        self.add_world_size(mode, world_size)
-        self.add_group(mode, process_group)
-        self.add_cpu_group(mode, cpu_group)
-        self.add_ranks_in_group(mode, ranks_in_group)
+        self._add_local_rank(mode, local_rank)
+        self._add_world_size(mode, world_size)
+        self._add_group(mode, process_group)
+        self._add_cpu_group(mode, cpu_group)
+        self._add_ranks_in_group(mode, ranks_in_group)
 
     def check_sanity(self):
         """Checks sanity of the parallel context.
diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py
index 457361ab4..6b6fdc5d7 100644
--- a/colossalai/context/process_group_initializer/initializer_2p5d.py
+++ b/colossalai/context/process_group_initializer/initializer_2p5d.py
@@ -105,8 +105,6 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
 
     def init_dist_group(self):
         """Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
@@ -161,8 +159,6 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
 
     def init_dist_group(self):
         """Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
@@ -218,8 +214,6 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
 
     def init_dist_group(self):
         """Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
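
The hardened assert in _check_parallel_mode now names the offending type instead
of failing silently. A minimal self-contained sketch of the behaviour (the
ParallelMode enum below is a stand-in for colossalai.context.ParallelMode, which
is not reproduced here):

    from enum import Enum

    class ParallelMode(Enum):
        # stand-in for colossalai.context.ParallelMode
        GLOBAL = 'global'

    def _check_parallel_mode(parallel_mode: ParallelMode):
        # same assert as the patched method: the message reports the actual type
        assert isinstance(parallel_mode, ParallelMode), \
            f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'

    _check_parallel_mode('global')
    # AssertionError: expected the argument parallel_mode to be of enum
    # ParallelMode, but got <class 'str'>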
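The renamed add_* setters follow the usual Python convention that a leading
underscore marks internal API: callers are expected to go through _register_dist
rather than mutate the per-mode tables directly. A simplified sketch of that
pattern (assumed, reduced state; only two of the five tables are shown and real
process-group handles are omitted):

    class _ContextSketch:
        """Simplified stand-in for ParallelContext."""

        def __init__(self):
            self._local_ranks = {}
            self._world_sizes = {}

        def _add_local_rank(self, mode, rank):
            # internal setter, mirroring the renamed _add_* methods in the patch
            self._local_ranks[mode] = rank

        def _add_world_size(self, mode, world_size):
            self._world_sizes[mode] = world_size

        def _register_dist(self, local_rank, world_size, mode):
            # the single entry point that fans out to the private setters
            self._add_local_rank(mode, local_rank)
            self._add_world_size(mode, world_size)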
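The asserts dropped from the three 2.5D initializers all enforced the same shape
invariant: a 2.5D tensor-parallel group holds depth * dim ** 2 processes. If that
constraint is to be checked at a single site instead, a helper could look like the
following (check_tesseract_shape is hypothetical and illustrative, not part of
colossalai):

    def check_tesseract_shape(tensor_parallel_size: int, tesseract_dep: int, tesseract_dim: int) -> None:
        # hypothetical helper: the invariant formerly asserted in
        # Initializer_2p5D_Col/Dep/XZ
        if tensor_parallel_size != tesseract_dim ** 2 * tesseract_dep:
            raise ValueError(
                f'Tensor parallel size should be depth * dim ** 2 in 2.5D parallel, '
                f'but got size={tensor_parallel_size}, depth={tesseract_dep}, dim={tesseract_dim}')

    check_tesseract_shape(8, 2, 2)  # passes: 2 * 2 ** 2 == 8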