mirror of https://github.com/hpcaitech/ColossalAI
[zero] chunk manager allows filtering ex-large params (#1393)
parent
adf5054ff8
commit
56b8863b87
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
@ -61,9 +62,6 @@ class ChunkManager:
|
||||||
if isinstance(tensor, ColoTensor):
|
if isinstance(tensor, ColoTensor):
|
||||||
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
|
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
|
||||||
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
|
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
|
||||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
|
||||||
raise ValueError(
|
|
||||||
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
|
|
||||||
try:
|
try:
|
||||||
# append the tensor to the last chunk
|
# append the tensor to the last chunk
|
||||||
self.chunk_groups[group_name][-1].append(tensor)
|
self.chunk_groups[group_name][-1].append(tensor)
|
||||||
|
@ -71,7 +69,10 @@ class ChunkManager:
|
||||||
# the except statement will be triggered when there is no chunk or
|
# the except statement will be triggered when there is no chunk or
|
||||||
# the last chunk in the chunk group is full
|
# the last chunk in the chunk group is full
|
||||||
# this will create a new chunk and allocate this chunk to its corresponding process
|
# this will create a new chunk and allocate this chunk to its corresponding process
|
||||||
chunk_size = self.chunk_size or tensor.numel()
|
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||||
|
chunk_size = tensor.numel()
|
||||||
|
else:
|
||||||
|
chunk_size = self.chunk_size or tensor.numel()
|
||||||
src_rank = self._get_next_src_rank(group_name)
|
src_rank = self._get_next_src_rank(group_name)
|
||||||
chunk = Chunk(chunk_size,
|
chunk = Chunk(chunk_size,
|
||||||
src_rank,
|
src_rank,
|
||||||
|
@ -263,7 +264,8 @@ class ChunkManager:
|
||||||
def search_chunk_size(module: torch.nn.Module,
|
def search_chunk_size(module: torch.nn.Module,
|
||||||
search_range: int,
|
search_range: int,
|
||||||
n_grids: int,
|
n_grids: int,
|
||||||
min_chunk_size: Optional[int] = None) -> int:
|
min_chunk_size: Optional[int] = None,
|
||||||
|
filter_exlarge_params: bool = True) -> int:
|
||||||
"""
|
"""
|
||||||
Search for the chunk size for optimal chunk utilization.
|
Search for the chunk size for optimal chunk utilization.
|
||||||
|
|
||||||
|
@ -278,6 +280,8 @@ class ChunkManager:
|
||||||
assert search_range % n_grids == 0
|
assert search_range % n_grids == 0
|
||||||
# TODO(ver217): sort params and filter unused ones
|
# TODO(ver217): sort params and filter unused ones
|
||||||
params_numel = [p.numel() for p in module.parameters()]
|
params_numel = [p.numel() for p in module.parameters()]
|
||||||
|
if filter_exlarge_params:
|
||||||
|
params_numel = _filter_exlarge_params(params_numel)
|
||||||
max_param_numel = max(params_numel)
|
max_param_numel = max(params_numel)
|
||||||
if min_chunk_size is not None:
|
if min_chunk_size is not None:
|
||||||
assert min_chunk_size >= max_param_numel
|
assert min_chunk_size >= max_param_numel
|
||||||
|
@ -330,3 +334,11 @@ class ChunkManager:
|
||||||
"""
|
"""
|
||||||
assert tensor not in self.tensor_chunk_map
|
assert tensor not in self.tensor_chunk_map
|
||||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
|
||||||
|
params_numel_arr = np.array(params_numel)
|
||||||
|
std = np.std(params_numel_arr)
|
||||||
|
mean = np.mean(params_numel_arr)
|
||||||
|
upper_limit = mean + 3 * std
|
||||||
|
return list(filter(lambda x: x <= upper_limit, params_numel))
|
||||||
|
|
Loading…
Reference in New Issue