Browse Source

[utils] support detection of number of processes on current node (#723)

pull/729/head
Frank Lee 3 years ago committed by GitHub
parent
commit
04ff5ea546
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      colossalai/context/parallel_context.py
  2. 13
      colossalai/initialize.py

10
colossalai/context/parallel_context.py

@ -2,6 +2,8 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import random import random
import socket
from collections import Counter
from typing import Union from typing import Union
import numpy as np import numpy as np
@ -45,6 +47,7 @@ class ParallelContext(metaclass=SingletonMeta):
self.data_parallel_size = 1 self.data_parallel_size = 1
self.pipeline_parallel_size = 1 self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1 self.tensor_parallel_size = 1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None self.virtual_pipeline_parallel_rank = None
@ -81,6 +84,13 @@ class ParallelContext(metaclass=SingletonMeta):
else: else:
raise TypeError("Invalid type for config, only dictionary or string is supported") raise TypeError("Invalid type for config, only dictionary or string is supported")
def detect_num_processes_on_current_node(self):
hostname = socket.gethostname()
hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
counter = Counter(hostname_list)
self.num_processes_on_current_node = counter[hostname]
@staticmethod @staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode): def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode) assert isinstance(parallel_mode, ParallelMode)

13
colossalai/initialize.py

@ -102,6 +102,9 @@ def launch(config: Union[str, Path, Config, Dict],
# if local rank is not given, calculate automatically # if local rank is not given, calculate automatically
gpc.set_device(local_rank) gpc.set_device(local_rank)
# set the number of processes running on the same node
gpc.detect_num_processes_on_current_node()
gpc.set_seed(seed) gpc.set_seed(seed)
if verbose: if verbose:
@ -398,15 +401,17 @@ def initialize(model: nn.Module,
else: else:
scatter_gather = False scatter_gather = False
if use_interleaved: if use_interleaved:
schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
gpc.config.model.num_chunks, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather) gpc.config.model.num_chunks,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather)
else: else:
schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather) tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather)
else: else:
schedule = NonPipelineSchedule() schedule = NonPipelineSchedule()
if gradient_handler_cfg is None: if gradient_handler_cfg is None:
gradient_handlers = None gradient_handlers = None
if verbose and not isinstance(model, DDP): if verbose and not isinstance(model, DDP):

Loading…
Cancel
Save