mirror of https://github.com/hpcaitech/ColossalAI
[zero] chunk manager allows filtering ex-large params (#1393)
parent
adf5054ff8
commit
56b8863b87
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
||||
from collections import deque
|
||||
|
||||
|
@ -61,9 +62,6 @@ class ChunkManager:
|
|||
if isinstance(tensor, ColoTensor):
|
||||
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
|
||||
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
|
||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||
raise ValueError(
|
||||
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
|
||||
try:
|
||||
# append the tensor to the last chunk
|
||||
self.chunk_groups[group_name][-1].append(tensor)
|
||||
|
@ -71,7 +69,10 @@ class ChunkManager:
|
|||
# the except statement will be triggered when there is no chunk or
|
||||
# the last chunk in the chunk group is full
|
||||
# this will create a new chunk and allocate this chunk to its corresponding process
|
||||
chunk_size = self.chunk_size or tensor.numel()
|
||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||
chunk_size = tensor.numel()
|
||||
else:
|
||||
chunk_size = self.chunk_size or tensor.numel()
|
||||
src_rank = self._get_next_src_rank(group_name)
|
||||
chunk = Chunk(chunk_size,
|
||||
src_rank,
|
||||
|
@ -263,7 +264,8 @@ class ChunkManager:
|
|||
def search_chunk_size(module: torch.nn.Module,
|
||||
search_range: int,
|
||||
n_grids: int,
|
||||
min_chunk_size: Optional[int] = None) -> int:
|
||||
min_chunk_size: Optional[int] = None,
|
||||
filter_exlarge_params: bool = True) -> int:
|
||||
"""
|
||||
Search for the chunk size for optimal chunk utilization.
|
||||
|
||||
|
@ -278,6 +280,8 @@ class ChunkManager:
|
|||
assert search_range % n_grids == 0
|
||||
# TODO(ver217): sort params and filter unused ones
|
||||
params_numel = [p.numel() for p in module.parameters()]
|
||||
if filter_exlarge_params:
|
||||
params_numel = _filter_exlarge_params(params_numel)
|
||||
max_param_numel = max(params_numel)
|
||||
if min_chunk_size is not None:
|
||||
assert min_chunk_size >= max_param_numel
|
||||
|
@ -330,3 +334,11 @@ class ChunkManager:
|
|||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
||||
|
||||
|
||||
def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
|
||||
params_numel_arr = np.array(params_numel)
|
||||
std = np.std(params_numel_arr)
|
||||
mean = np.mean(params_numel_arr)
|
||||
upper_limit = mean + 3 * std
|
||||
return list(filter(lambda x: x <= upper_limit, params_numel))
|
||||
|
|
Loading…
Reference in New Issue