fix format parallel_context.py (#359)

Co-authored-by: huangziyu <202476410arsmart@gmail.com>
pull/394/head
ziyu huang 2022-03-10 09:29:32 +08:00 committed by Frank Lee
parent c695369af0
commit a77d73f22b
1 changed file with 23 additions and 34 deletions


@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
from typing import Union
@@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
class ParallelContext:
-    """This class provides interface functions for users to get the parallel context,
+    """This class provides interface functions for users to get the parallel context,
    such as the global rank, the local rank, the world size, etc. of each device.
    """
@@ -218,7 +217,8 @@ class ParallelContext:
    def is_pipeline_last_stage(self, ignore_virtual=False):
        if not ignore_virtual:
-            if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+            if self.virtual_pipeline_parallel_size \
+                    is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                return False
        return self.is_last_rank(ParallelMode.PIPELINE)
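
To make the wrapped condition concrete: with an interleaved (virtual) pipeline schedule, only the last virtual chunk on the last pipeline rank counts as the last stage. A hypothetical reading, with values invented purely for illustration:

# Suppose this process is the last rank of ParallelMode.PIPELINE and
# virtual_pipeline_parallel_size == 2 (interleaved schedule):
#   virtual_pipeline_parallel_rank == 0 -> is_pipeline_last_stage() returns False
#   virtual_pipeline_parallel_rank == 1 -> is_pipeline_last_stage() returns True
# is_pipeline_last_stage(ignore_virtual=True) returns True in both cases.
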
@@ -300,13 +300,7 @@ class ParallelContext:
        self._check_parallel_mode(parallel_mode)
        self._ranks_in_group[parallel_mode] = ranks

-    def init_global_dist(self,
-                         rank: int,
-                         world_size: int,
-                         backend: str,
-                         host: str,
-                         port: int
-                         ):
+    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
        """Initializes the global distributed environment
        :param rank: rank for the default process group
        :type rank: int
@@ -321,18 +315,13 @@ class ParallelContext:
        """
        # initialize the default process group
        init_method = f'tcp://{host}:{port}'
-        dist.init_process_group(rank=rank,
-                                world_size=world_size,
-                                backend=backend,
-                                init_method=init_method)
+        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

        # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None,
-                            list(range(world_size)), ParallelMode.GLOBAL)
+        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
        self.add_global_rank(ParallelMode.GLOBAL, rank)

-    def _register_dist(self, local_rank, world_size,
-                       process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
        self.add_local_rank(mode, local_rank)
        self.add_world_size(mode, world_size)
        self.add_group(mode, process_group)
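
For orientation, the reformatted call above is the standard PyTorch TCP rendezvous. A self-contained sketch with placeholder values follows; it is plain torch.distributed, not ColossalAI-specific, and the function name and example arguments are invented here.

import torch.distributed as dist

def init_global_dist_sketch(rank: int, world_size: int, backend: str, host: str, port: int):
    # Every process dials the same TCP endpoint; the chosen host/port act as the rendezvous.
    init_method = f'tcp://{host}:{port}'
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

# e.g. a single-process smoke test: init_global_dist_sketch(0, 1, 'gloo', '127.0.0.1', 29500)
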
@@ -349,7 +338,9 @@ class ParallelContext:
        tps = self.tensor_parallel_size
        ws = self.world_size
        assert ws == dps * pps * \
-            tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+            tps, f"Expected the world size {ws} to be equal to data" \
+            f" parallel size ({dps}) * pipeline parallel size " \
+            f"({pps}) * tensor parallel size ({tps})"

    def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
        if key in config:
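
The assertion encodes the 3D decomposition world_size = data_parallel_size * pipeline_parallel_size * tensor_parallel_size. A quick check with made-up sizes:

# Hypothetical sizes: 8 processes split 2-way along each axis.
dps, pps, tps = 2, 2, 2
ws = 8
assert ws == dps * pps * tps   # 2 * 2 * 2 == 8, so check_sanity would pass
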
@@ -360,8 +351,7 @@ class ParallelContext:
                setattr(self, attr_name, ele['size'])
            else:
                raise NotImplementedError(
-                    f"Parallel configuration does not support this kind of argument, please use int or dict"
-                )
+                    f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')

    def init_parallel_groups(self):
        """Initializes the parallel groups.
@@ -386,11 +376,13 @@ class ParallelContext:
        # get the tensor parallel mode and check
        tensor_parallel_mode = None
-        if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+        if parallel_config is not None and 'tensor' in \
+                parallel_config and 'mode' in parallel_config['tensor']:
            tensor_parallel_mode = parallel_config['tensor']['mode']
-        assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+        assert tensor_parallel_mode in ALLOWED_MODES, \
+            f"mode in the parallel config must be set to one of {ALLOWED_MODES}"

        env.mode = tensor_parallel_mode

        self.check_sanity()

        pg_init = []
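
For context, the `parallel` section being parsed here typically looks like the snippet below. Sizes and mode are illustrative only; the bare-int and dict forms correspond to the two branches handled in `_set_parallel_size_from_config`.

# Illustrative ColossalAI-style config; the values are placeholders.
parallel = dict(
    pipeline=2,                       # int form: just a parallel size
    tensor=dict(size=4, mode='2d'),   # dict form: size plus a tensor-parallel mode
)
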
@@ -426,12 +418,10 @@ class ParallelContext:
        for initializer_cfg in pg_init:
            cfg = initializer_cfg.copy()
            initializer_type = cfg.pop('type')
-            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
-                rank, world_size, self.config,
-                self.data_parallel_size,
-                self.pipeline_parallel_size,
-                self.tensor_parallel_size,
-                **cfg)
+            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
+                                                                              self.data_parallel_size,
+                                                                              self.pipeline_parallel_size,
+                                                                              self.tensor_parallel_size, **cfg)
            parallel_setting = initializer.init_dist_group()
            if isinstance(parallel_setting, list):
                for args in parallel_setting:
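
The loop above is registry-driven dispatch: each `pg_init` entry carries a `type` key, `DIST_GROUP_INITIALIZER.get_module` resolves it to an initializer class, and `init_dist_group()` returns the arguments later handed to `_register_dist`. A stripped-down sketch of that pattern follows; the registry, class name, and return stub are hypothetical, not the real `DIST_GROUP_INITIALIZER`.

# Toy registry illustrating the dispatch pattern used above; the real
# DIST_GROUP_INITIALIZER in ColossalAI is more featureful.
_REGISTRY = {}

def register_module(name):
    def _wrap(cls):
        _REGISTRY[name] = cls
        return cls
    return _wrap

def get_module(name):
    return _REGISTRY[name]

@register_module('Initializer_Example')
class ExampleInitializer:
    def __init__(self, rank, world_size, config, data_parallel_size, pipeline_parallel_size,
                 tensor_parallel_size, **cfg):
        self.rank, self.world_size = rank, world_size

    def init_dist_group(self):
        # The real initializers return roughly the (local_rank, world_size, process_group,
        # ranks_in_group, mode) tuple consumed by _register_dist; this stub returns nothing.
        return None

initializer = get_module('Initializer_Example')(0, 8, {}, 2, 2, 2)
initializer.init_dist_group()
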
@@ -509,10 +499,9 @@ class ParallelContext:
            seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
            if self._verbose:
-                self._logger.info(
-                    f"initialized seed on rank {global_rank}, "
-                    f"numpy: {seed}, python random: {seed}, {seed_str},"
-                    f"the default parallel seed is {ParallelMode.DATA}.")
+                self._logger.info(f"initialized seed on rank {global_rank}, "
+                                  f"numpy: {seed}, python random: {seed}, {seed_str},"
+                                  f"the default parallel seed is {ParallelMode.DATA}.")
        else:
            if self._verbose:
                self._logger.info(