From 04ff5ea546086aaabbb98873e77a7ce624b097c3 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 12 Apr 2022 09:28:19 +0800 Subject: [PATCH] [utils] support detection of number of processes on current node (#723) --- colossalai/context/parallel_context.py | 10 ++++++++++ colossalai/initialize.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index 959eb4a9a..ceb2064a6 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -2,6 +2,8 @@ # -*- encoding: utf-8 -*- import random +import socket +from collections import Counter from typing import Union import numpy as np @@ -45,6 +47,7 @@ class ParallelContext(metaclass=SingletonMeta): self.data_parallel_size = 1 self.pipeline_parallel_size = 1 self.tensor_parallel_size = 1 + self.num_processes_on_current_node = -1 self.virtual_pipeline_parallel_size = None self.virtual_pipeline_parallel_rank = None @@ -81,6 +84,13 @@ class ParallelContext(metaclass=SingletonMeta): else: raise TypeError("Invalid type for config, only dictionary or string is supported") + def detect_num_processes_on_current_node(self): + hostname = socket.gethostname() + hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))] + dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL)) + counter = Counter(hostname_list) + self.num_processes_on_current_node = counter[hostname] + @staticmethod def _check_parallel_mode(parallel_mode: ParallelMode): assert isinstance(parallel_mode, ParallelMode) diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 9435b37e3..e1da29cb0 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -102,6 +102,9 @@ def launch(config: Union[str, Path, Config, Dict], # if local rank is not given, calculate automatically gpc.set_device(local_rank) + # set the number of processes running on the same node + gpc.detect_num_processes_on_current_node() + gpc.set_seed(seed) if verbose: @@ -398,15 +401,17 @@ def initialize(model: nn.Module, else: scatter_gather = False if use_interleaved: - schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather) + schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) else: schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather) + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather) else: schedule = NonPipelineSchedule() - if gradient_handler_cfg is None: gradient_handlers = None if verbose and not isinstance(model, DDP):