[zero] chunk manager allows filtering ex-large params (#1393)

ver217 2022-08-02 10:40:27 +08:00 committed by GitHub
parent adf5054ff8
commit 56b8863b87
1 changed file with 17 additions and 5 deletions


@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
@@ -61,9 +62,6 @@ class ChunkManager:
         if isinstance(tensor, ColoTensor):
             assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
             ), f"Chunk Manager can only manage ColoTensor with the same DP process group"
-        if self.chunk_size is not None and tensor.numel() > self.chunk_size:
-            raise ValueError(
-                f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
         try:
             # append the tensor to the last chunk
             self.chunk_groups[group_name][-1].append(tensor)
@@ -71,7 +69,10 @@ class ChunkManager:
             # the except statement will be triggered when there is no chunk or
             # the last chunk in the chunk group is full
             # this will create a new chunk and allocate this chunk to its corresponding process
-            chunk_size = self.chunk_size or tensor.numel()
+            if self.chunk_size is not None and tensor.numel() > self.chunk_size:
+                chunk_size = tensor.numel()
+            else:
+                chunk_size = self.chunk_size or tensor.numel()
             src_rank = self._get_next_src_rank(group_name)
             chunk = Chunk(chunk_size,
                           src_rank,
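
Behavior note: before this commit, appending a tensor larger than the configured chunk size raised a ValueError; now such an ex-large tensor simply gets a dedicated chunk sized to hold it alone. A minimal sketch of the new sizing rule (the pick_chunk_size helper is hypothetical, not part of this commit):

# Hypothetical helper illustrating the sizing rule introduced above.
def pick_chunk_size(configured_size, tensor_numel):
    if configured_size is not None and tensor_numel > configured_size:
        return tensor_numel                  # ex-large tensor: dedicated chunk
    return configured_size or tensor_numel   # otherwise use the configured size

assert pick_chunk_size(1024, 4096) == 4096   # previously raised ValueError
assert pick_chunk_size(1024, 512) == 1024    # normal tensors share fixed-size chunks
assert pick_chunk_size(None, 512) == 512     # no configured size: one chunk per tensor
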
@@ -263,7 +264,8 @@ class ChunkManager:
     def search_chunk_size(module: torch.nn.Module,
                           search_range: int,
                           n_grids: int,
-                          min_chunk_size: Optional[int] = None) -> int:
+                          min_chunk_size: Optional[int] = None,
+                          filter_exlarge_params: bool = True) -> int:
         """
         Search for the chunk size for optimal chunk utilization.
@@ -278,6 +280,8 @@ class ChunkManager:
         assert search_range % n_grids == 0
         # TODO(ver217): sort params and filter unused ones
         params_numel = [p.numel() for p in module.parameters()]
+        if filter_exlarge_params:
+            params_numel = _filter_exlarge_params(params_numel)
         max_param_numel = max(params_numel)
         if min_chunk_size is not None:
             assert min_chunk_size >= max_param_numel
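
With the new keyword argument, callers of search_chunk_size can toggle the filtering. A usage sketch, assuming ChunkManager is importable from this module and substituting a toy model (the model and the numeric arguments are placeholders):

import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 8))
chunk_size = ChunkManager.search_chunk_size(model,
                                            search_range=1024,    # must be divisible by n_grids
                                            n_grids=8,
                                            filter_exlarge_params=True)
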
@@ -330,3 +334,11 @@ class ChunkManager:
         """
         assert tensor not in self.tensor_chunk_map
         self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
+
+
+def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
+    params_numel_arr = np.array(params_numel)
+    std = np.std(params_numel_arr)
+    mean = np.mean(params_numel_arr)
+    upper_limit = mean + 3 * std
+    return list(filter(lambda x: x <= upper_limit, params_numel))
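
The new _filter_exlarge_params drops any parameter whose element count lies more than three standard deviations above the mean, so a single huge parameter (e.g. an embedding table) no longer dominates max_param_numel during the search. A self-contained check of that three-sigma rule (the sample sizes are made up):

import numpy as np

sizes = [1000] * 20 + [1_000_000]            # 20 typical params plus one ex-large outlier
arr = np.array(sizes)
upper_limit = arr.mean() + 3 * arr.std()     # same bound as _filter_exlarge_params
kept = [s for s in sizes if s <= upper_limit]
assert kept == [1000] * 20                   # the outlier is removed
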