mirror of https://github.com/hpcaitech/ColossalAI
fix format parallel_context.py (#359)
Co-authored-by: huangziyu <202476410arsmart@gmail.com>pull/394/head
parent
c695369af0
commit
a77d73f22b
|
@ -1,7 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
|
@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
|
|||
|
||||
|
||||
class ParallelContext:
|
||||
"""This class provides interface functions for users to get the parallel context,
|
||||
"""This class provides interface functions for users to get the parallel context,
|
||||
such as the global rank, the local rank, the world size, etc. of each device.
|
||||
|
||||
"""
|
||||
|
@ -218,7 +217,8 @@ class ParallelContext:
|
|||
|
||||
def is_pipeline_last_stage(self, ignore_virtual=False):
|
||||
if not ignore_virtual:
|
||||
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
|
||||
if self.virtual_pipeline_parallel_size \
|
||||
is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
|
||||
return False
|
||||
return self.is_last_rank(ParallelMode.PIPELINE)
|
||||
|
||||
|
@ -300,13 +300,7 @@ class ParallelContext:
|
|||
self._check_parallel_mode(parallel_mode)
|
||||
self._ranks_in_group[parallel_mode] = ranks
|
||||
|
||||
def init_global_dist(self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
host: str,
|
||||
port: int
|
||||
):
|
||||
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
|
||||
"""Initializes the global distributed environment
|
||||
:param rank: rank for the default process group
|
||||
:type rank: int
|
||||
|
@ -321,18 +315,13 @@ class ParallelContext:
|
|||
"""
|
||||
# initialize the default process group
|
||||
init_method = f'tcp://{host}:{port}'
|
||||
dist.init_process_group(rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
init_method=init_method)
|
||||
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
|
||||
|
||||
# None will give the default global process group for pytorch dist operations
|
||||
self._register_dist(rank, world_size, None,
|
||||
list(range(world_size)), ParallelMode.GLOBAL)
|
||||
self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
|
||||
self.add_global_rank(ParallelMode.GLOBAL, rank)
|
||||
|
||||
def _register_dist(self, local_rank, world_size,
|
||||
process_group, ranks_in_group, mode):
|
||||
def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
|
||||
self.add_local_rank(mode, local_rank)
|
||||
self.add_world_size(mode, world_size)
|
||||
self.add_group(mode, process_group)
|
||||
|
@ -349,7 +338,9 @@ class ParallelContext:
|
|||
tps = self.tensor_parallel_size
|
||||
ws = self.world_size
|
||||
assert ws == dps * pps * \
|
||||
tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
|
||||
tps, f"Expected the world size {ws} to be equal to data" \
|
||||
f" parallel size ({dps}) * pipeline parallel size " \
|
||||
f"({pps}) * tensor parallel size ({tps})"
|
||||
|
||||
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
|
||||
if key in config:
|
||||
|
@ -360,8 +351,7 @@ class ParallelContext:
|
|||
setattr(self, attr_name, ele['size'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Parallel configuration does not support this kind of argument, please use int or dict"
|
||||
)
|
||||
f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')
|
||||
|
||||
def init_parallel_groups(self):
|
||||
"""Initializes the parallel groups.
|
||||
|
@ -386,11 +376,13 @@ class ParallelContext:
|
|||
|
||||
# get the tensor parallel mode and check
|
||||
tensor_parallel_mode = None
|
||||
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
|
||||
if parallel_config is not None and 'tensor' in \
|
||||
parallel_config and 'mode' in parallel_config['tensor']:
|
||||
tensor_parallel_mode = parallel_config['tensor']['mode']
|
||||
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
|
||||
assert tensor_parallel_mode in ALLOWED_MODES, \
|
||||
f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
|
||||
env.mode = tensor_parallel_mode
|
||||
|
||||
|
||||
self.check_sanity()
|
||||
|
||||
pg_init = []
|
||||
|
@ -426,12 +418,10 @@ class ParallelContext:
|
|||
for initializer_cfg in pg_init:
|
||||
cfg = initializer_cfg.copy()
|
||||
initializer_type = cfg.pop('type')
|
||||
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
|
||||
rank, world_size, self.config,
|
||||
self.data_parallel_size,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
**cfg)
|
||||
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
|
||||
self.data_parallel_size,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size, **cfg)
|
||||
parallel_setting = initializer.init_dist_group()
|
||||
if isinstance(parallel_setting, list):
|
||||
for args in parallel_setting:
|
||||
|
@ -509,10 +499,9 @@ class ParallelContext:
|
|||
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
|
||||
|
||||
if self._verbose:
|
||||
self._logger.info(
|
||||
f"initialized seed on rank {global_rank}, "
|
||||
f"numpy: {seed}, python random: {seed}, {seed_str},"
|
||||
f"the default parallel seed is {ParallelMode.DATA}.")
|
||||
self._logger.info(f"initialized seed on rank {global_rank}, "
|
||||
f"numpy: {seed}, python random: {seed}, {seed_str},"
|
||||
f"the default parallel seed is {ParallelMode.DATA}.")
|
||||
else:
|
||||
if self._verbose:
|
||||
self._logger.info(
|
||||
|
|
Loading…
Reference in New Issue