ColossalAI/colossalai/elixir/search/simple.py

import math
from typing import Tuple

import torch
import torch.nn as nn

from .base import SearchBase
from .result import SearchResult
from .utils import get_multi_used_params, to_divide


class SearchSimple(SearchBase):
"""The simple search algorithm used for unit tests.
Developers can specify the number of chunks used.
args:
module: the module to be searched
default_group_size: the default group size of communications
dtype: the data type of the parameters
prefetch: whether to prefetch the chunks
verbose: whether to print the search process
inp: the example input of the model
step_fn: the example step function of training
"""
def __init__(self,
module: nn.Module,
default_group_size: int,
dtype: torch.dtype = torch.float,
prefetch: bool = False,
verbose: bool = False,
inp=None,
step_fn=None) -> None:
super().__init__(module, dtype, prefetch, verbose, inp, step_fn)
self.default_group_size = default_group_size
def private_truncate(self, param: nn.Parameter) -> int:
return to_divide(param.numel(), self.default_group_size)
def public_trucate(self, length: int) -> int:
return to_divide(length, self.default_group_size)
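    # Both truncation hooks above rely on `to_divide`, which is assumed here to
    # round its first argument up to the nearest multiple of the second: e.g. with
    # default_group_size=4, a parameter of 10 elements would be padded to 12 so
    # that every rank in the process group holds an equally sized shard.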

    def search(self, split_number: int, allocate_factor: float) -> Tuple:
        # get multi-used parameters
        private_params = get_multi_used_params(self.meta_module)
        # get parameters that are used only once
        public_params = [p for p in self.meta_module.parameters() if p not in private_params]
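        # Parameters used more than once per step form the private group; the
        # remaining single-use parameters are packed into public groups below.
        # (Presumably the base class later gives each private parameter its own
        # chunk and shares the public block space among the public groups.)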
        # calculate the size of each group
        len_public = len(public_params)
        split_number = min(len_public, split_number)
        # allocate a list for groups
        public_groups = list()
        if split_number > 0:
            average_size = len_public // split_number
            left_size = len_public % split_number
            # set the size of each segment
            pack_size_list = [average_size] * split_number
            for i in range(split_number):
                if left_size > 0:
                    pack_size_list[i] += 1
                    left_size -= 1
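            # e.g. len_public=10 and split_number=3 give average_size=3 and
            # left_size=1, so pack_size_list becomes [4, 3, 3]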
            # split public parameters
            for i in range(split_number):
                p_list = list()
                for _ in range(pack_size_list[i]):
                    p = public_params.pop(0)
                    p_list.append(p)
                public_groups.append(p_list)
            assert len(public_params) == 0

            # calculate the maximum total size among the groups
            max_sum_size = 0
            for p_list in public_groups:
                sum_size = sum(p.numel() for p in p_list)
                max_sum_size = max(max_sum_size, sum_size)
        else:
            max_sum_size = 0

        self.public_block_size = max_sum_size
        self.public_block_number = math.ceil(split_number * allocate_factor)
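        # e.g. split_number=10 with allocate_factor=0.6 reserves ceil(6.0) = 6
        # public blocks, each sized to hold the largest public group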
        return (private_params, public_groups)


def simple_search(m: nn.Module,
                  group_size: int,
                  split_number: int = 10,
                  allocate_factor: float = 0.6,
                  unified_dtype: torch.dtype = torch.float,
                  shard_device: torch.device = torch.device('cpu'),
                  prefetch: bool = False,
                  verbose: bool = False,
                  inp=None,
                  step_fn=None) -> SearchResult:
    """Runs the simple search algorithm on ``m`` and returns the resulting chunk
    plans and chunk group as a ``SearchResult``.

    Args:
        m: the module to be searched
        group_size: the process group size used for communications
        split_number: the number of groups that single-use (public) parameters are split into
        allocate_factor: the ratio used to decide how many public blocks to allocate,
            i.e. ``ceil(split_number * allocate_factor)``
        unified_dtype: the data type used for the parameters
        shard_device: the device on which parameter shards are placed
        prefetch: whether to prefetch the chunks
        verbose: whether to print the search process
        inp: an example input of the model
        step_fn: an example step function for training
    """
    search_class = SearchSimple(
        # pre-commit: do not rearrange
        module=m,
        default_group_size=group_size,
        dtype=unified_dtype,
        prefetch=prefetch,
        verbose=verbose,
        inp=inp,
        step_fn=step_fn)
    private_group, public_groups = search_class.search(split_number, allocate_factor)
    chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
    # assign shard device
    for plan in chunk_plans:
        plan.kwargs['shard_device'] = shard_device
    chunk_group = search_class.allocate_chunk_group(chunk_plans)

    return SearchResult(chunk_group=chunk_group,
                        chunk_plans=chunk_plans,
                        param_called_per_step=search_class.param_per_step)
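

# A minimal usage sketch (illustrative only; it assumes a toy CPU model and that
# the default `inp`/`step_fn` are enough for SearchBase's tracing, which may not
# hold for every module):
#
#   import torch.nn as nn
#   from colossalai.elixir.search.simple import simple_search
#
#   model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 8))
#   result = simple_search(model, group_size=1, split_number=4)
#   print(result.chunk_plans)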