
[usability] improved error messages in the context module (#856)

Branch: pull/867/head
Frank Lee authored 3 years ago; committed by GitHub
Commit: 2238758c2e
  1. colossalai/context/parallel_context.py (26 changed lines)
  2. colossalai/context/process_group_initializer/initializer_2p5d.py (6 changed lines)

colossalai/context/parallel_context.py

@@ -4,6 +4,7 @@
 import random
 import socket
 from collections import Counter
+from threading import local
 from typing import Union

 import numpy as np
@@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
     @staticmethod
     def _check_parallel_mode(parallel_mode: ParallelMode):
-        assert isinstance(parallel_mode, ParallelMode)
+        assert isinstance(parallel_mode, ParallelMode), \
+            f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'

     def get_global_rank(self):
         """Returns the global rank of the current device.
@@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._local_ranks[parallel_mode]

-    def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
+    def _add_local_rank(self, parallel_mode: ParallelMode, rank: int):
         """Adds the local rank of the current device for `parallel_mode` to the context.

         Args:
@@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._world_sizes[parallel_mode]

-    def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
+    def _add_world_size(self, parallel_mode: ParallelMode, world_size: int):
         """Adds world size for `parallel_mode`.

         Args:
-            parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+            parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group.
             world_size (int): The world size to be added

         Raises:
@@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._groups[parallel_mode]

-    def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+    def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
         """Adds the group of the current device for `parallel_mode`.

         Args:
@@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._cpu_groups[parallel_mode]

-    def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+    def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
         """Adds the Gloo group of the current device for `parallel_mode`.

         :param parallel_mode: The chosen parallel mode
@@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._ranks_in_group[parallel_mode]

-    def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
+    def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
         """Adds the ranks of the current device for `parallel_mode` in the group.

         Args:
@@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
         self.add_global_rank(ParallelMode.GLOBAL, rank)

     def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
-        self.add_local_rank(mode, local_rank)
-        self.add_world_size(mode, world_size)
-        self.add_group(mode, process_group)
-        self.add_cpu_group(mode, cpu_group)
-        self.add_ranks_in_group(mode, ranks_in_group)
+        self._add_local_rank(mode, local_rank)
+        self._add_world_size(mode, world_size)
+        self._add_group(mode, process_group)
+        self._add_cpu_group(mode, cpu_group)
+        self._add_ranks_in_group(mode, ranks_in_group)

     def check_sanity(self):
         """Checks sanity of the parallel context.

colossalai/context/process_group_initializer/initializer_2p5d.py

@@ -105,8 +105,6 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"

     def init_dist_group(self):
         """Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
@@ -161,8 +159,6 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"

     def init_dist_group(self):
         """Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
@@ -218,8 +214,6 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"

     def init_dist_group(self):
         """Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
