mirror of https://github.com/hpcaitech/ColossalAI
[zero] add chunk size searching algorithm for parameters in different groups (#1436)
parent
c9427a323f
commit
9056677b13
|
@ -1 +1,2 @@
|
||||||
from .chunkv2 import ChunkV2
|
from .chunkv2 import ChunkV2
|
||||||
|
from .search_utils import clasify_params, search_chunk_configuration
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
from typing import Dict, List
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
from colossalai.tensor import ColoParameter
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
||||||
|
"""Filter those parameters whose size is too large from others.
|
||||||
|
"""
|
||||||
|
params_size = [p.numel() for p in model.parameters()]
|
||||||
|
params_size_arr = np.array(params_size)
|
||||||
|
|
||||||
|
std = np.std(params_size_arr)
|
||||||
|
mean = np.mean(params_size_arr)
|
||||||
|
upper_limit = mean + 3 * std
|
||||||
|
|
||||||
|
for key in size_dict:
|
||||||
|
org_list = size_dict[key]
|
||||||
|
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
||||||
|
"""Get unused byte for a certain chunk size.
|
||||||
|
"""
|
||||||
|
acc = 0
|
||||||
|
left = 0
|
||||||
|
for s in size_list:
|
||||||
|
if s > left:
|
||||||
|
acc += left
|
||||||
|
left = chunk_size
|
||||||
|
left -= s
|
||||||
|
return left + acc
|
||||||
|
|
||||||
|
|
||||||
|
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
|
||||||
|
params_dict: Dict[int, List[ColoParameter]] = dict()
|
||||||
|
for param in model.parameters():
|
||||||
|
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
|
||||||
|
param_key = param.process_group.dp_world_size()
|
||||||
|
|
||||||
|
if param_key not in params_dict:
|
||||||
|
params_dict[param_key] = []
|
||||||
|
params_dict[param_key].append(param)
|
||||||
|
|
||||||
|
return params_dict
|
||||||
|
|
||||||
|
|
||||||
|
def search_chunk_configuration(
|
||||||
|
model: nn.Module,
|
||||||
|
search_range_mb: int,
|
||||||
|
search_interval_byte: int, # hidden size is the best value for the interval
|
||||||
|
min_chunk_size_mb: int = 32,
|
||||||
|
filter_exlarge_params: bool = True
|
||||||
|
):
|
||||||
|
search_range_byte = search_range_mb * 1024 ** 2
|
||||||
|
min_chunk_size_byte = min_chunk_size_mb * 1024 ** 2
|
||||||
|
assert search_range_byte % search_interval_byte == 0
|
||||||
|
|
||||||
|
params_dict = clasify_params(model)
|
||||||
|
config_dict: Dict[int, Dict] = dict()
|
||||||
|
|
||||||
|
size_dict: Dict[int, List[int]] = dict()
|
||||||
|
for key in params_dict:
|
||||||
|
params_list = params_dict[key]
|
||||||
|
size_list = [p.numel() for p in params_list]
|
||||||
|
# let small parameters keep gathered in CUDA all the time
|
||||||
|
total_size = sum(size_list)
|
||||||
|
if total_size < min_chunk_size_byte:
|
||||||
|
config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
|
||||||
|
else:
|
||||||
|
size_dict[key] = size_list
|
||||||
|
|
||||||
|
if filter_exlarge_params:
|
||||||
|
_filter_exlarge_params(model, size_dict)
|
||||||
|
|
||||||
|
max_size = min_chunk_size_byte
|
||||||
|
for key in size_dict:
|
||||||
|
max_size = max(max_size, max(size_dict[key]))
|
||||||
|
|
||||||
|
min_chunk_waste = float('+inf')
|
||||||
|
best_chunk_size = max_size
|
||||||
|
|
||||||
|
for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte):
|
||||||
|
temp_waste = 0
|
||||||
|
for key in size_dict:
|
||||||
|
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
|
||||||
|
if temp_waste < min_chunk_waste:
|
||||||
|
min_chunk_waste = temp_waste
|
||||||
|
best_chunk_size = chunk_size
|
||||||
|
|
||||||
|
for key in params_dict:
|
||||||
|
if key in config_dict:
|
||||||
|
continue
|
||||||
|
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
|
||||||
|
|
||||||
|
return config_dict
|
|
@ -0,0 +1,67 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.gemini.update import search_chunk_configuration
|
||||||
|
from colossalai.utils import free_port, get_current_device
|
||||||
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
|
||||||
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
|
||||||
|
|
||||||
|
def init_1d_row_spec(model, pg: ProcessGroup):
|
||||||
|
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||||
|
for n, p in model.named_parameters():
|
||||||
|
if 'weight' in n and 'ln' not in n:
|
||||||
|
p.set_process_group(pg)
|
||||||
|
p.set_tensor_spec(*tensor_spec)
|
||||||
|
|
||||||
|
|
||||||
|
def exam_search_chunk_size():
|
||||||
|
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
pg_tp = ProcessGroup(tp_degree=world_size)
|
||||||
|
|
||||||
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
|
# make sure torch_model and model has the same parameter values
|
||||||
|
with ColoInitContext(device=get_current_device()):
|
||||||
|
model = model_builder()
|
||||||
|
init_1d_row_spec(model, pg_tp)
|
||||||
|
config_dict = search_chunk_configuration(
|
||||||
|
model,
|
||||||
|
search_range_mb=1,
|
||||||
|
search_interval_byte=16,
|
||||||
|
min_chunk_size_mb=0,
|
||||||
|
filter_exlarge_params=True)
|
||||||
|
|
||||||
|
for key in config_dict:
|
||||||
|
chunk_size = config_dict[key]['chunk_size']
|
||||||
|
if world_size == 1:
|
||||||
|
assert chunk_size == 31616
|
||||||
|
else:
|
||||||
|
assert chunk_size == 1024
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
exam_search_chunk_size()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_search(world_size):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_search(4)
|
Loading…
Reference in New Issue