mirror of https://github.com/hpcaitech/ColossalAI
[usability] improved error messages in the context module (#856)
parent
9fdebadd69
commit
2238758c2e
|
@ -4,6 +4,7 @@
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
from threading import local
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_parallel_mode(parallel_mode: ParallelMode):
|
def _check_parallel_mode(parallel_mode: ParallelMode):
|
||||||
assert isinstance(parallel_mode, ParallelMode)
|
assert isinstance(parallel_mode, ParallelMode), \
|
||||||
|
f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'
|
||||||
|
|
||||||
def get_global_rank(self):
|
def get_global_rank(self):
|
||||||
"""Returns the global rank of the current device.
|
"""Returns the global rank of the current device.
|
||||||
|
@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
self._check_parallel_mode(parallel_mode)
|
self._check_parallel_mode(parallel_mode)
|
||||||
return self._local_ranks[parallel_mode]
|
return self._local_ranks[parallel_mode]
|
||||||
|
|
||||||
def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
|
def _add_local_rank(self, parallel_mode: ParallelMode, rank: int):
|
||||||
"""Adds the local rank of the current device for `parallel_mode` to the context.
|
"""Adds the local rank of the current device for `parallel_mode` to the context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
self._check_parallel_mode(parallel_mode)
|
self._check_parallel_mode(parallel_mode)
|
||||||
return self._world_sizes[parallel_mode]
|
return self._world_sizes[parallel_mode]
|
||||||
|
|
||||||
def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
|
def _add_world_size(self, parallel_mode: ParallelMode, world_size: int):
|
||||||
"""Adds world size for `parallel_mode`.
|
"""Adds world size for `parallel_mode`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
|
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group
|
||||||
world_size (int): The world size to be added
|
world_size (int): The world size to be added
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
self._check_parallel_mode(parallel_mode)
|
self._check_parallel_mode(parallel_mode)
|
||||||
return self._groups[parallel_mode]
|
return self._groups[parallel_mode]
|
||||||
|
|
||||||
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
|
def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
|
||||||
"""Adds the group of the current device for `parallel_mode`.
|
"""Adds the group of the current device for `parallel_mode`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
self._check_parallel_mode(parallel_mode)
|
self._check_parallel_mode(parallel_mode)
|
||||||
return self._cpu_groups[parallel_mode]
|
return self._cpu_groups[parallel_mode]
|
||||||
|
|
||||||
def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
|
def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
|
||||||
"""Adds the Gloo group of the current device for `parallel_mode`.
|
"""Adds the Gloo group of the current device for `parallel_mode`.
|
||||||
|
|
||||||
:param parallel_mode: The chosen parallel mode
|
:param parallel_mode: The chosen parallel mode
|
||||||
|
@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
self._check_parallel_mode(parallel_mode)
|
self._check_parallel_mode(parallel_mode)
|
||||||
return self._ranks_in_group[parallel_mode]
|
return self._ranks_in_group[parallel_mode]
|
||||||
|
|
||||||
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
|
def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
|
||||||
"""Adds the ranks of the current device for `parallel_mode` in the group.
|
"""Adds the ranks of the current device for `parallel_mode` in the group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
self.add_global_rank(ParallelMode.GLOBAL, rank)
|
self.add_global_rank(ParallelMode.GLOBAL, rank)
|
||||||
|
|
||||||
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
|
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
|
||||||
self.add_local_rank(mode, local_rank)
|
self._add_local_rank(mode, local_rank)
|
||||||
self.add_world_size(mode, world_size)
|
self._add_world_size(mode, world_size)
|
||||||
self.add_group(mode, process_group)
|
self._add_group(mode, process_group)
|
||||||
self.add_cpu_group(mode, cpu_group)
|
self._add_cpu_group(mode, cpu_group)
|
||||||
self.add_ranks_in_group(mode, ranks_in_group)
|
self._add_ranks_in_group(mode, ranks_in_group)
|
||||||
|
|
||||||
def check_sanity(self):
|
def check_sanity(self):
|
||||||
"""Checks sanity of the parallel context.
|
"""Checks sanity of the parallel context.
|
||||||
|
|
|
@ -105,8 +105,6 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
|
||||||
self.num_group = self.world_size // self.tensor_parallel_size
|
self.num_group = self.world_size // self.tensor_parallel_size
|
||||||
self.tesseract_dep = tesseract_dep
|
self.tesseract_dep = tesseract_dep
|
||||||
self.tesseract_dim = tesseract_dim
|
self.tesseract_dim = tesseract_dim
|
||||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
|
||||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
|
||||||
|
|
||||||
def init_dist_group(self):
|
def init_dist_group(self):
|
||||||
"""Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
|
"""Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
|
||||||
|
@ -161,8 +159,6 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
|
||||||
self.num_group = self.world_size // self.tensor_parallel_size
|
self.num_group = self.world_size // self.tensor_parallel_size
|
||||||
self.tesseract_dep = tesseract_dep
|
self.tesseract_dep = tesseract_dep
|
||||||
self.tesseract_dim = tesseract_dim
|
self.tesseract_dim = tesseract_dim
|
||||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
|
||||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
|
||||||
|
|
||||||
def init_dist_group(self):
|
def init_dist_group(self):
|
||||||
"""Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
|
"""Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
|
||||||
|
@ -218,8 +214,6 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
|
||||||
self.num_group = self.world_size // self.tensor_parallel_size
|
self.num_group = self.world_size // self.tensor_parallel_size
|
||||||
self.tesseract_dep = tesseract_dep
|
self.tesseract_dep = tesseract_dep
|
||||||
self.tesseract_dim = tesseract_dim
|
self.tesseract_dim = tesseract_dim
|
||||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
|
||||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
|
||||||
|
|
||||||
def init_dist_group(self):
|
def init_dist_group(self):
|
||||||
"""Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
|
"""Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
|
||||||
|
|
Loading…
Reference in New Issue