[zero] add chunk init function for users (#1729)

* add chunk manager init function

* fix unit tests

* add comment

* add flush=True
pull/1732/head
HELSON 2022-10-18 16:31:22 +08:00 committed by GitHub
parent 2e1dbfb463
commit f69f9bf223
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 691 additions and 629 deletions

View File

@ -1,3 +1,4 @@
from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration
from .utils import init_chunk_manager

View File

@ -1,100 +1,108 @@
import math
from typing import Dict, List
import numpy as np
import torch.nn as nn
from colossalai.tensor import ColoParameter
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""Filter those parameters whose size is too large from others.
"""
params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
params_size_arr = np.array(params_size)
std = np.std(params_size_arr)
mean = np.mean(params_size_arr)
upper_limit = mean + 3 * std
for key in size_dict:
org_list = size_dict[key]
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
"""Get unused byte for a certain chunk size.
"""
acc = 0
left = 0
for s in size_list:
if s > left:
acc += left
left = chunk_size
left -= s
return left + acc
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in model.parameters():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if getattr(param, '_ddp_to_ignore', False):
continue
param_key = param.process_group.dp_world_size()
if param_key not in params_dict:
params_dict[param_key] = []
params_dict[param_key].append(param)
return params_dict
def search_chunk_configuration(
model: nn.Module,
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True) -> Dict:
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = clasify_params(model)
config_dict: Dict[int, Dict] = dict()
size_dict: Dict[int, List[int]] = dict()
for key in params_dict:
params_list = params_dict[key]
size_list = [p.numel() for p in params_list]
# let small parameters keep gathered in CUDA all the time
total_size = sum(size_list)
if total_size < min_chunk_size_byte:
config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
else:
size_dict[key] = size_list
if filter_exlarge_params:
_filter_exlarge_params(model, size_dict)
max_size = min_chunk_size_byte
for key in size_dict:
max_size = max(max_size, max(size_dict[key]))
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
min_chunk_waste = float('+inf')
best_chunk_size = start_size
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
temp_waste = 0
for key in size_dict:
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
if temp_waste < min_chunk_waste:
min_chunk_waste = temp_waste
best_chunk_size = chunk_size
for key in params_dict:
if key in config_dict:
continue
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
return config_dict
import math
from typing import Dict, List, Tuple
import numpy as np
import torch.nn as nn
from colossalai.tensor import ColoParameter
def in_ddp(param: nn.Parameter) -> bool:
return not getattr(param, '_ddp_to_ignore', False)
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""Filter those parameters whose size is too large from others.
"""
params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
params_size_arr = np.array(params_size)
std = np.std(params_size_arr)
mean = np.mean(params_size_arr)
upper_limit = mean + 3 * std
for key in size_dict:
org_list = size_dict[key]
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
"""Get unused byte for a certain chunk size.
"""
acc = 0
left = 0
for s in size_list:
if s > left:
acc += left
left = chunk_size
left -= s
return left + acc
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
"""Clasify each parameter by its size of DP group.
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in model.parameters():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if not in_ddp(param):
continue
param_key = param.process_group.dp_world_size()
if param_key not in params_dict:
params_dict[param_key] = []
params_dict[param_key].append(param)
return params_dict
def search_chunk_configuration(
model: nn.Module,
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = clasify_params(model)
config_dict: Dict[int, Dict] = dict()
size_dict: Dict[int, List[int]] = dict()
for key in params_dict:
params_list = params_dict[key]
size_list = [p.numel() for p in params_list]
# let small parameters keep gathered in CUDA all the time
total_size = sum(size_list)
if total_size < min_chunk_size_byte:
config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
else:
size_dict[key] = size_list
if filter_exlarge_params:
_filter_exlarge_params(model, size_dict)
max_size = min_chunk_size_byte
for key in size_dict:
max_size = max(max_size, max(size_dict[key]))
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
min_chunk_waste = float('+inf')
best_chunk_size = start_size
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
temp_waste = 0
for key in size_dict:
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
if temp_waste < min_chunk_waste:
min_chunk_waste = temp_waste
best_chunk_size = chunk_size
for key in params_dict:
if key in config_dict:
continue
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
return config_dict, min_chunk_waste

View File

@ -0,0 +1,58 @@
from time import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
search_range_mb: Optional[float] = None,
min_chunk_size_mb: Optional[float] = None,
filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
kwargs_dict = dict()
if hidden_dim:
search_interval_byte = hidden_dim
else:
search_interval_byte = 1024 # 1kb
kwargs_dict["search_interval_byte"] = search_interval_byte
if search_range_mb:
kwargs_dict["search_range_mb"] = search_range_mb
if min_chunk_size_mb:
kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
if filter_exlarge_params:
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
total_size = sum(params_sizes) / 1024**2
dist.barrier()
begine = time()
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
dist.barrier()
end = time()
span_s = end - begine
wasted_size /= 1024**2
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
sep='',
flush=True)
dist.barrier()
chunk_manager = ChunkManager(config_dict, init_device)
return chunk_manager

View File

@ -1,21 +1,23 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable, Type
import torch.distributed as dist
import os
import random
from functools import partial
from typing import Callable, Type
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
def set_seed(seed):
@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
chunk_config = search_chunk_configuration(module, 4, 1024)
chunk_config, _ = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)

View File

@ -1,105 +1,104 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ProcessGroup
from tests.test_tensor.common_utils import debug_print
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
model.eval()
torch_model.eval()
set_seed(pg.dp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
model.backward(loss)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
check_grad(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
model.eval()
torch_model.eval()
set_seed(pg.dp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
model.backward(loss)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
check_grad(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)

View File

@ -1,118 +1,116 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)
from functools import partial
from time import time
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)

View File

@ -1,66 +1,65 @@
import pytest
from functools import partial
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_process_group(pg)
p.set_tensor_spec(*tensor_spec)
def exam_search_chunk_size():
world_size = torch.distributed.get_world_size()
pg_tp = ProcessGroup(tp_degree=world_size)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# make sure torch_model and model has the same parameter values
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
config_dict = search_chunk_configuration(model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True)
for key in config_dict:
chunk_size = config_dict[key]['chunk_size']
if world_size == 1:
assert chunk_size == 31616
else:
assert chunk_size == 1024
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_search_chunk_size()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_search(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_search(4)
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model, pg: ProcessGroup):
tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_process_group(pg)
p.set_tensor_spec(*tensor_spec)
def exam_search_chunk_size():
world_size = torch.distributed.get_world_size()
pg_tp = ProcessGroup(tp_degree=world_size)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# make sure torch_model and model has the same parameter values
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
config_dict, _ = search_chunk_configuration(model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True)
for key in config_dict:
chunk_size = config_dict[key]['chunk_size']
if world_size == 1:
assert chunk_size == 31616
else:
assert chunk_size == 1024
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_search_chunk_size()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_search(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_search(4)

View File

@ -1,110 +1,108 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False)
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
exam_load_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_ddp(1)
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False)
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
exam_load_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_ddp(1)

View File

@ -1,97 +1,95 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters())
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
set_seed(dist.get_rank() * 3 + 128)
model.train()
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
optim.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optim.backward(loss)
optim.step()
optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict)
new_state = optim.state_dict()['state']
org_state = optim_state_dict['state']
for k, v in org_state.items():
w = new_state[k]
for n, m in v.items():
if isinstance(m, torch.Tensor):
o = w[n]
if m.device != o.device:
o = o.to(m.device)
assert torch.equal(m, o)
else:
assert m == w[n]
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_zero_optim_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_optim(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_optim(1)
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters())
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
set_seed(dist.get_rank() * 3 + 128)
model.train()
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
optim.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optim.backward(loss)
optim.step()
optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict)
new_state = optim.state_dict()['state']
org_state = optim_state_dict['state']
for k, v in org_state.items():
w = new_state[k]
for n, m in v.items():
if isinstance(m, torch.Tensor):
o = w[n]
if m.device != o.device:
o = o.to(m.device)
assert torch.equal(m, o)
else:
assert m == w[n]
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_zero_optim_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_optim(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_optim(1)

View File

@ -1,23 +1,24 @@
from functools import partial
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
@ -88,7 +89,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
tp_init_spec_func(model, pg)
dp_world_size = pg.dp_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[dp_world_size]['chunk_size'] = 5000
config_dict[dp_world_size]['keep_gathered'] = False
if placement_policy != 'cuda':