diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index b81c0b452..379497b48 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import os
 import random
 from typing import Union
 
@@ -20,7 +19,7 @@ from .random import add_seed, get_seeds, set_mode
 
 
 class ParallelContext:
-    """This class provides interface functions for users to get the parallel context, 
+    """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
 
     """
@@ -218,7 +217,8 @@ class ParallelContext:
 
     def is_pipeline_last_stage(self, ignore_virtual=False):
         if not ignore_virtual:
-            if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
+            if self.virtual_pipeline_parallel_size \
+                    is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
                 return False
         return self.is_last_rank(ParallelMode.PIPELINE)
 
@@ -300,13 +300,7 @@ class ParallelContext:
         self._check_parallel_mode(parallel_mode)
         self._ranks_in_group[parallel_mode] = ranks
 
-    def init_global_dist(self,
-                         rank: int,
-                         world_size: int,
-                         backend: str,
-                         host: str,
-                         port: int
-                         ):
+    def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
         """Initializes the global distributed environment
         :param rank: rank for the default process group
         :type rank: int
@@ -321,18 +315,13 @@ class ParallelContext:
         """
         # initialize the default process group
        init_method = f'tcp://{host}:{port}'
-        dist.init_process_group(rank=rank,
-                                world_size=world_size,
-                                backend=backend,
-                                init_method=init_method)
+        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None,
-                            list(range(world_size)), ParallelMode.GLOBAL)
+        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)
 
-    def _register_dist(self, local_rank, world_size,
-                       process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
@@ -349,7 +338,9 @@ class ParallelContext:
         tps = self.tensor_parallel_size
         ws = self.world_size
         assert ws == dps * pps * \
-            tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
+            tps, f"Expected the world size {ws} to be equal to data" \
+            f" parallel size ({dps}) * pipeline parallel size " \
+            f"({pps}) * tensor parallel size ({tps})"
 
     def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
         if key in config:
@@ -360,8 +351,7 @@ class ParallelContext:
                 setattr(self, attr_name, ele['size'])
             else:
                 raise NotImplementedError(
-                    f"Parallel configuration does not support this kind of argument, please use int or dict"
-                )
+                    f'{"Parallel configuration does not support this kind of argument, please use int or dict"}')
 
     def init_parallel_groups(self):
         """Initializes the parallel groups.
@@ -386,11 +376,13 @@ class ParallelContext:
 
         # get the tensor parallel mode and check
         tensor_parallel_mode = None
-        if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
+        if parallel_config is not None and 'tensor' in \
+                parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
-            assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
+            assert tensor_parallel_mode in ALLOWED_MODES, \
+                f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
         env.mode = tensor_parallel_mode
-
+
         self.check_sanity()
 
         pg_init = []
@@ -426,12 +418,10 @@ class ParallelContext:
         for initializer_cfg in pg_init:
             cfg = initializer_cfg.copy()
             initializer_type = cfg.pop('type')
-            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
-                rank, world_size, self.config,
-                self.data_parallel_size,
-                self.pipeline_parallel_size,
-                self.tensor_parallel_size,
-                **cfg)
+            initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config,
+                                                                              self.data_parallel_size,
+                                                                              self.pipeline_parallel_size,
+                                                                              self.tensor_parallel_size, **cfg)
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
                 for args in parallel_setting:
@@ -509,10 +499,9 @@ class ParallelContext:
             seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
 
             if self._verbose:
-                self._logger.info(
-                    f"initialized seed on rank {global_rank}, "
-                    f"numpy: {seed}, python random: {seed}, {seed_str},"
-                    f"the default parallel seed is {ParallelMode.DATA}.")
+                self._logger.info(f"initialized seed on rank {global_rank}, "
+                                  f"numpy: {seed}, python random: {seed}, {seed_str},"
+                                  f"the default parallel seed is {ParallelMode.DATA}.")
             else:
                 if self._verbose:
                     self._logger.info(
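
Note on usage (illustrative, not part of the patch): the hunks above reformat init_global_dist, _register_dist, check_sanity and init_parallel_groups without changing behavior. The sketch below shows how those methods are typically driven per process; the direct ParallelContext() construction, the ctx variable, the rank/world-size/host/port values, the example sizes and the '2d' mode string are all assumptions for illustration, and a real run is expected to go through the library's own launch and config-loading path instead.

# Hypothetical per-process driver, for illustration only.
from colossalai.context.parallel_context import ParallelContext

ctx = ParallelContext()  # assumption: direct construction; the library normally
                         # manages a global context instance for you.

# init_global_dist() builds init_method = f'tcp://{host}:{port}', calls
# dist.init_process_group(), and registers the GLOBAL process group.
# rank differs per process; 0 is shown as an example.
ctx.init_global_dist(rank=0, world_size=16, backend='nccl', host='127.0.0.1', port=29500)

# init_parallel_groups() reads the 'parallel' section of the attached config.
# Per _set_parallel_size_from_config(), each entry may be an int or a dict
# with a 'size' key, and 'tensor' may carry a 'mode' that is checked against
# ALLOWED_MODES ('2d' below is an assumed value). check_sanity() then asserts
# world_size == data size * pipeline size * tensor size, e.g. 16 == 2 * 2 * 4.
#
#     parallel = dict(pipeline=2, tensor=dict(size=4, mode='2d'))
#
# Assumption: such a config has already been attached to ctx through the
# library's config-loading path before this call.
ctx.init_parallel_groups()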