[utils] support detection of number of processes on current node (#723)

pull/729/head
Frank Lee 3 years ago committed by GitHub
parent 4d90a7b513
commit 04ff5ea546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,8 @@
# -*- encoding: utf-8 -*-
import random
import socket
from collections import Counter
from typing import Union
import numpy as np
@ -45,6 +47,7 @@ class ParallelContext(metaclass=SingletonMeta):
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
@ -81,6 +84,13 @@ class ParallelContext(metaclass=SingletonMeta):
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
def detect_num_processes_on_current_node(self):
hostname = socket.gethostname()
hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
counter = Counter(hostname_list)
self.num_processes_on_current_node = counter[hostname]
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode)

@ -102,6 +102,9 @@ def launch(config: Union[str, Path, Config, Dict],
# if local rank is not given, calculate automatically
gpc.set_device(local_rank)
# set the number of processes running on the same node
gpc.detect_num_processes_on_current_node()
gpc.set_seed(seed)
if verbose:
@ -399,14 +402,16 @@ def initialize(model: nn.Module,
scatter_gather = False
if use_interleaved:
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
gpc.config.model.num_chunks,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather)
else:
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather)
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather)
else:
schedule = NonPipelineSchedule()
if gradient_handler_cfg is None:
gradient_handlers = None
if verbose and not isinstance(model, DDP):

Loading…
Cancel
Save