|
|
|
@ -4,6 +4,7 @@
|
|
|
|
|
import random |
|
|
|
|
import socket |
|
|
|
|
from collections import Counter |
|
|
|
|
from threading import local |
|
|
|
|
from typing import Union |
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _check_parallel_mode(parallel_mode: ParallelMode): |
|
|
|
|
assert isinstance(parallel_mode, ParallelMode) |
|
|
|
|
assert isinstance(parallel_mode, ParallelMode), \ |
|
|
|
|
f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' |
|
|
|
|
|
|
|
|
|
def get_global_rank(self): |
|
|
|
|
"""Returns the global rank of the current device. |
|
|
|
@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
self._check_parallel_mode(parallel_mode) |
|
|
|
|
return self._local_ranks[parallel_mode] |
|
|
|
|
|
|
|
|
|
def add_local_rank(self, parallel_mode: ParallelMode, rank: int): |
|
|
|
|
def _add_local_rank(self, parallel_mode: ParallelMode, rank: int): |
|
|
|
|
"""Adds the local rank of the current device for `parallel_mode` to the context. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
self._check_parallel_mode(parallel_mode) |
|
|
|
|
return self._world_sizes[parallel_mode] |
|
|
|
|
|
|
|
|
|
def add_world_size(self, parallel_mode: ParallelMode, world_size: int): |
|
|
|
|
def _add_world_size(self, parallel_mode: ParallelMode, world_size: int): |
|
|
|
|
"""Adds world size for `parallel_mode`. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. |
|
|
|
|
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group |
|
|
|
|
world_size (int): The world size to be added |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
self._check_parallel_mode(parallel_mode) |
|
|
|
|
return self._groups[parallel_mode] |
|
|
|
|
|
|
|
|
|
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): |
|
|
|
|
def _add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): |
|
|
|
|
"""Adds the group of the current device for `parallel_mode`. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
self._check_parallel_mode(parallel_mode) |
|
|
|
|
return self._cpu_groups[parallel_mode] |
|
|
|
|
|
|
|
|
|
def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): |
|
|
|
|
def _add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): |
|
|
|
|
"""Adds the Gloo group of the current device for `parallel_mode`. |
|
|
|
|
|
|
|
|
|
:param parallel_mode: The chosen parallel mode |
|
|
|
@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
self._check_parallel_mode(parallel_mode) |
|
|
|
|
return self._ranks_in_group[parallel_mode] |
|
|
|
|
|
|
|
|
|
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): |
|
|
|
|
def _add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): |
|
|
|
|
"""Adds the ranks of the current device for `parallel_mode` in the group. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
|
|
|
|
|
self.add_global_rank(ParallelMode.GLOBAL, rank) |
|
|
|
|
|
|
|
|
|
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode): |
|
|
|
|
self.add_local_rank(mode, local_rank) |
|
|
|
|
self.add_world_size(mode, world_size) |
|
|
|
|
self.add_group(mode, process_group) |
|
|
|
|
self.add_cpu_group(mode, cpu_group) |
|
|
|
|
self.add_ranks_in_group(mode, ranks_in_group) |
|
|
|
|
self._add_local_rank(mode, local_rank) |
|
|
|
|
self._add_world_size(mode, world_size) |
|
|
|
|
self._add_group(mode, process_group) |
|
|
|
|
self._add_cpu_group(mode, cpu_group) |
|
|
|
|
self._add_ranks_in_group(mode, ranks_in_group) |
|
|
|
|
|
|
|
|
|
def check_sanity(self): |
|
|
|
|
"""Checks sanity of the parallel context. |
|
|
|
|