[zero] add chunk init function for users (#1729)

* add chunk manager init function

* fix unit tests

* add comment

* add flush=True
pull/1732/head
HELSON 2022-10-18 16:31:22 +08:00 committed by GitHub
parent 2e1dbfb463
commit f69f9bf223
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 691 additions and 629 deletions

View File

@ -1,3 +1,4 @@
from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration
from .utils import init_chunk_manager

View File

@ -1,14 +1,20 @@
import math
from typing import Dict, List
from typing import Dict, List, Tuple
import numpy as np
import torch.nn as nn
from colossalai.tensor import ColoParameter
def in_ddp(param: nn.Parameter) -> bool:
return not getattr(param, '_ddp_to_ignore', False)
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""Filter those parameters whose size is too large from others.
"""
params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
params_size_arr = np.array(params_size)
std = np.std(params_size_arr)
@ -34,10 +40,12 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
"""Clasify each parameter by its size of DP group.
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in model.parameters():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if getattr(param, '_ddp_to_ignore', False):
if not in_ddp(param):
continue
param_key = param.process_group.dp_world_size()
@ -54,7 +62,7 @@ def search_chunk_configuration(
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True) -> Dict:
filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
@ -97,4 +105,4 @@ def search_chunk_configuration(
continue
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
return config_dict
return config_dict, min_chunk_waste

View File

@ -0,0 +1,58 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
kwargs_dict = dict()
if hidden_dim:
search_interval_byte = hidden_dim
else:
search_interval_byte = 1024 # 1kb
kwargs_dict["search_interval_byte"] = search_interval_byte
if search_range_mb:
kwargs_dict["search_range_mb"] = search_range_mb
if min_chunk_size_mb:
kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
if filter_exlarge_params:
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
total_size = sum(params_sizes) / 1024**2
dist.barrier()
begine = time()
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
dist.barrier()
end = time()
span_s = end - begine
wasted_size /= 1024**2
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
sep='',
flush=True)
dist.barrier()
chunk_manager = ChunkManager(config_dict, init_device)
return chunk_manager

View File

@ -1,21 +1,23 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable, Type
import torch.distributed as dist
import os
import random
from functools import partial
from typing import Callable, Type
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
def set_seed(seed):
@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
chunk_config = search_chunk_configuration(module, 4, 1024)
chunk_config, _ = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)

View File

@ -1,23 +1,22 @@
from functools import partial
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from tests.test_tensor.common_utils import debug_print
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
@ -54,7 +53,7 @@ def exam_gpt_fwd_bwd(placement_policy):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
chunk_manager = ChunkManager(config_dict)

View File

@ -1,27 +1,25 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
@ -62,7 +60,7 @@ def exam_gpt_fwd_bwd(placement_policy):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':

View File

@ -1,17 +1,16 @@
import pytest
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
from tests.components_to_test.registry import non_distributed_component_funcs
@ -35,7 +34,7 @@ def exam_search_chunk_size():
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
config_dict = search_chunk_configuration(model,
config_dict, _ = search_chunk_configuration(model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,

View File

@ -1,22 +1,20 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@ -34,7 +32,7 @@ def exam_state_dict(placement_policy, keep_gathered):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
@ -67,7 +65,7 @@ def exam_load_state_dict(placement_policy, keep_gathered):
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered

View File

@ -1,24 +1,22 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@ -35,7 +33,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered

View File

@ -1,23 +1,24 @@
from functools import partial
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
@ -88,7 +89,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
tp_init_spec_func(model, pg)
dp_world_size = pg.dp_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[dp_world_size]['chunk_size'] = 5000
config_dict[dp_world_size]['keep_gathered'] = False
if placement_policy != 'cuda':