
[usability] improved error messages in the context module (#856)

Frank Lee authored 3 years ago, committed by GitHub
commit 2238758c2e
1. colossalai/context/parallel_context.py (26 lines changed)
2. colossalai/context/process_group_initializer/initializer_2p5d.py (6 lines changed)

colossalai/context/parallel_context.py

@@ -4,6 +4,7 @@
 import random
 import socket
 from collections import Counter
+from threading import local
 from typing import Union

 import numpy as np
@@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
     @staticmethod
     def _check_parallel_mode(parallel_mode: ParallelMode):
-        assert isinstance(parallel_mode, ParallelMode)
+        assert isinstance(parallel_mode, ParallelMode), \
+            f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'

     def get_global_rank(self):
         """Returns the global rank of the current device.
@@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._local_ranks[parallel_mode]

-    def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
+    def _add_local_rank(self, parallel_mode: ParallelMode, rank: int):
         """Adds the local rank of the current device for `parallel_mode` to the context.

         Args:
@@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._world_sizes[parallel_mode]

-    def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
+    def _add_world_size(self, parallel_mode: ParallelMode, world_size: int):
         """Adds world size for `parallel_mode`.

         Args:
-            parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+            parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group.
             world_size (int): The world size to be added

         Raises:
@@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._groups[parallel_mode]

-    def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+    def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
         """Adds the group of the current device for `parallel_mode`.

         Args:
@@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._cpu_groups[parallel_mode]

-    def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+    def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
         """Adds the Gloo group of the current device for `parallel_mode`.

         :param parallel_mode: The chosen parallel mode
@@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         return self._ranks_in_group[parallel_mode]

-    def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
+    def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
         """Adds the ranks of the current device for `parallel_mode` in the group.

         Args:
@@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
         self.add_global_rank(ParallelMode.GLOBAL, rank)

     def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
-        self.add_local_rank(mode, local_rank)
-        self.add_world_size(mode, world_size)
-        self.add_group(mode, process_group)
-        self.add_cpu_group(mode, cpu_group)
-        self.add_ranks_in_group(mode, ranks_in_group)
+        self._add_local_rank(mode, local_rank)
+        self._add_world_size(mode, world_size)
+        self._add_group(mode, process_group)
+        self._add_cpu_group(mode, cpu_group)
+        self._add_ranks_in_group(mode, ranks_in_group)

     def check_sanity(self):
         """Checks sanity of the parallel context.

colossalai/context/process_group_initializer/initializer_2p5d.py

@@ -105,8 +105,6 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"

     def init_dist_group(self):
         """Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
@@ -161,8 +159,6 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"

     def init_dist_group(self):
         """Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
@@ -218,8 +214,6 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
         self.num_group = self.world_size // self.tensor_parallel_size
         self.tesseract_dep = tesseract_dep
         self.tesseract_dim = tesseract_dim
-        assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
-            "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"

     def init_dist_group(self):
         """Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
