fix format parallel_context.py (#359)

Co-authored-by: huangziyu <202476410arsmart@gmail.com>
pull/394/head
ziyu huang 2022-03-10 09:29:32 +08:00 committed by Frank Lee
parent c695369af0
commit a77d73f22b
1 changed files with 23 additions and 34 deletions

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
from typing import Union
@ -218,7 +217,8 @@ class ParallelContext:
def is_pipeline_last_stage(self, ignore_virtual=False):
if not ignore_virtual:
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
if self.virtual_pipeline_parallel_size \
is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
return False
return self.is_last_rank(ParallelMode.PIPELINE)
@ -300,13 +300,7 @@ class ParallelContext:
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self,
rank: int,
world_size: int,
backend: str,
host: str,
port: int
):
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
@ -321,18 +315,13 @@ class ParallelContext:
"""
# initialize the default process group
init_method = f'tcp://{host}:{port}'
dist.init_process_group(rank=rank,
world_size=world_size,
backend=backend,
init_method=init_method)
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
self.add_global_rank(ParallelMode.GLOBAL, rank)
def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
self.add_local_rank(mode, local_rank)
self.add_world_size(mode, world_size)
self.add_group(mode, process_group)
@ -349,7 +338,9 @@ class ParallelContext:
tps = self.tensor_parallel_size
ws = self.world_size
assert ws == dps * pps * \
tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
tps, f"Expected the world size {ws} to be equal to data" \
f" parallel size ({dps}) * pipeline parallel size " \
f"({pps}) * tensor parallel size ({tps})"
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config:
@ -360,8 +351,7 @@ class ParallelContext:
setattr(self, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')
def init_parallel_groups(self):
"""Initializes the parallel groups.
@ -386,9 +376,11 @@ class ParallelContext:
# get the tensor parallel mode and check
tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
if parallel_config is not None and 'tensor' in \
parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
assert tensor_parallel_mode in ALLOWED_MODES, \
f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
env.mode = tensor_parallel_mode
self.check_sanity()
@ -426,12 +418,10 @@ class ParallelContext:
for initializer_cfg in pg_init:
cfg = initializer_cfg.copy()
initializer_type = cfg.pop('type')
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size,
**cfg)
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size, **cfg)
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
for args in parallel_setting:
@ -509,10 +499,9 @@ class ParallelContext:
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.")
self._logger.info(f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.")
else:
if self._verbose:
self._logger.info(