[zero] chunk manager allows filtering ex-large params (#1393)

ver217 2022-08-02 10:40:27 +08:00 committed by GitHub
parent adf5054ff8
commit 56b8863b87
1 changed file with 17 additions and 5 deletions


@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
@ -61,9 +62,6 @@ class ChunkManager:
if isinstance(tensor, ColoTensor): if isinstance(tensor, ColoTensor):
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group( assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
), f"Chunk Manager can only manage ColoTensor with the same DP process group" ), f"Chunk Manager can only manage ColoTensor with the same DP process group"
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
raise ValueError(
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
try: try:
# append the tensor to the last chunk # append the tensor to the last chunk
self.chunk_groups[group_name][-1].append(tensor) self.chunk_groups[group_name][-1].append(tensor)
@@ -71,7 +69,10 @@
             # the except statement will be triggered when there is no chunk or
             # the last chunk in the chunk group is full
             # this will create a new chunk and allocate this chunk to its corresponding process
-            chunk_size = self.chunk_size or tensor.numel()
+            if self.chunk_size is not None and tensor.numel() > self.chunk_size:
+                chunk_size = tensor.numel()
+            else:
+                chunk_size = self.chunk_size or tensor.numel()
             src_rank = self._get_next_src_rank(group_name)
             chunk = Chunk(chunk_size,
                           src_rank,
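
The net effect of the two hunks above: an ex-large parameter no longer aborts chunk creation with a ValueError; it is instead placed in a dedicated chunk sized exactly to its numel. Below is a minimal standalone sketch of that sizing rule (pick_chunk_size is a hypothetical helper written for illustration, not part of the patch):

    from typing import Optional

    def pick_chunk_size(fixed_chunk_size: Optional[int], tensor_numel: int) -> int:
        # hypothetical helper: mirrors the new branching in the patch above
        if fixed_chunk_size is not None and tensor_numel > fixed_chunk_size:
            # ex-large tensor: give it a dedicated chunk of its own size
            return tensor_numel
        # otherwise use the configured chunk size, or the tensor's numel
        # when chunking is disabled (chunk_size is None)
        return fixed_chunk_size or tensor_numel

    assert pick_chunk_size(1024, 300) == 1024        # normal param shares a fixed-size chunk
    assert pick_chunk_size(1024, 50_000) == 50_000   # ex-large param gets its own chunk
    assert pick_chunk_size(None, 300) == 300         # no chunking: one chunk per tensor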
@@ -263,7 +264,8 @@ class ChunkManager:
     def search_chunk_size(module: torch.nn.Module,
                           search_range: int,
                           n_grids: int,
-                          min_chunk_size: Optional[int] = None) -> int:
+                          min_chunk_size: Optional[int] = None,
+                          filter_exlarge_params: bool = True) -> int:
         """
         Search for the chunk size for optimal chunk utilization.
@@ -278,6 +280,8 @@ class ChunkManager:
         assert search_range % n_grids == 0
         # TODO(ver217): sort params and filter unused ones
         params_numel = [p.numel() for p in module.parameters()]
+        if filter_exlarge_params:
+            params_numel = _filter_exlarge_params(params_numel)
         max_param_numel = max(params_numel)
         if min_chunk_size is not None:
             assert min_chunk_size >= max_param_numel
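
With the new keyword, callers can opt in or out of outlier filtering when searching for a chunk size. A hedged usage sketch follows; the import path is an assumption and may differ across Colossal-AI versions:

    import torch
    from colossalai.gemini import ChunkManager  # assumed import path

    model = torch.nn.Sequential(torch.nn.Linear(128, 128),
                                torch.nn.Linear(128, 512))

    # search_range must be divisible by n_grids (asserted in the hunk above);
    # filter_exlarge_params=True is the default and drops statistical outliers
    # before the utilization search.
    chunk_size = ChunkManager.search_chunk_size(model,
                                                search_range=64 * 1024,
                                                n_grids=1024,
                                                filter_exlarge_params=True)
    print(f'best chunk size: {chunk_size}')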
@@ -330,3 +334,11 @@ class ChunkManager:
         """
         assert tensor not in self.tensor_chunk_map
         self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
+
+
+def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
+    params_numel_arr = np.array(params_numel)
+    std = np.std(params_numel_arr)
+    mean = np.mean(params_numel_arr)
+    upper_limit = mean + 3 * std
+    return list(filter(lambda x: x <= upper_limit, params_numel))
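
The helper drops any size more than three standard deviations above the mean (np.std defaults to the population std, ddof=0). Note that a single outlier among n values has a z-score bounded by (n-1)/sqrt(n), so at least 11 parameters are needed before one ex-large parameter can actually be filtered. A small worked example, using a standalone copy of the helper with hypothetical sizes:

    import numpy as np
    from typing import List

    def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
        arr = np.array(params_numel)
        upper_limit = np.mean(arr) + 3 * np.std(arr)  # mean + 3 sigma cutoff
        return [x for x in params_numel if x <= upper_limit]

    sizes = [1000] * 10 + [1_000_000]
    # mean ~= 91818, std ~= 287192, so upper_limit ~= 953395:
    # the 1_000_000 param exceeds it and is dropped
    print(_filter_exlarge_params(sizes))  # [1000, 1000, ..., 1000]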