mirror of https://github.com/hpcaitech/ColossalAI
[gemini] add fake_release_chunk for keep-gathered chunk in the inference mode (#2671)
parent
0966008839
commit
8213f89fd2
|
@ -140,6 +140,14 @@ class ChunkManager:
|
||||||
self.__add_memory_usage(chunk.memory_usage)
|
self.__add_memory_usage(chunk.memory_usage)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def fake_release_chunk(self, chunk: Chunk) -> None:
|
||||||
|
"""Release gathered chunk in a fake mode.
|
||||||
|
This function is used for keep-gathered chunk in the inference mode.
|
||||||
|
"""
|
||||||
|
assert chunk.keep_gathered
|
||||||
|
assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors
|
||||||
|
self.__sub_accessed_chunk(chunk)
|
||||||
|
|
||||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
|
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
|
||||||
"""
|
"""
|
||||||
Copy data to the chunk.
|
Copy data to the chunk.
|
||||||
|
|
|
@ -257,8 +257,11 @@ class ZeroDDP(ColoDDP):
|
||||||
access_list = list(self.chunk_manager.accessed_chunks)
|
access_list = list(self.chunk_manager.accessed_chunks)
|
||||||
# we need to scatter all accessed chunks and move them to their original places
|
# we need to scatter all accessed chunks and move them to their original places
|
||||||
for chunk in access_list:
|
for chunk in access_list:
|
||||||
assert chunk.can_release
|
if chunk.keep_gathered:
|
||||||
self.chunk_manager.release_chunk(chunk)
|
self.chunk_manager.fake_release_chunk(chunk)
|
||||||
|
else:
|
||||||
|
assert chunk.can_release
|
||||||
|
self.chunk_manager.release_chunk(chunk)
|
||||||
first_param = next(iter(chunk.tensors_info))
|
first_param = next(iter(chunk.tensors_info))
|
||||||
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
|
||||||
assert self.chunk_manager.accessed_mem == 0
|
assert self.chunk_manager.accessed_mem == 0
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
@ -13,7 +14,7 @@ from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chu
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
|
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
|
||||||
from colossalai.nn.parallel import ZeroDDP
|
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper
|
||||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
|
@ -36,9 +37,35 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||||
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
|
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||||
|
config_dict[world_size]['chunk_size'] = 5000
|
||||||
|
config_dict[world_size]['keep_gathered'] = False
|
||||||
|
if placement_policy != 'cuda':
|
||||||
|
init_device = torch.device('cpu')
|
||||||
|
else:
|
||||||
|
init_device = None
|
||||||
|
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||||
|
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
|
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def single_chunk_init(model: torch.nn.Module, placement_policy: str):
|
||||||
|
gemini_config = dict(
|
||||||
|
device=get_current_device(),
|
||||||
|
placement_policy=placement_policy,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||||
@parameterize('model_name', ['gpt2'])
|
@parameterize('model_name', ['gpt2'])
|
||||||
def exam_inference(placement_policy, model_name: str):
|
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
|
||||||
|
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
|
||||||
set_seed(19360226)
|
set_seed(19360226)
|
||||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
@ -56,18 +83,7 @@ def exam_inference(placement_policy, model_name: str):
|
||||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||||
p.data.copy_(torch_p.data)
|
p.data.copy_(torch_p.data)
|
||||||
|
|
||||||
world_size = torch.distributed.get_world_size()
|
model = model_init_func(model, placement_policy)
|
||||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
|
||||||
config_dict[world_size]['chunk_size'] = 5000
|
|
||||||
config_dict[world_size]['keep_gathered'] = False
|
|
||||||
if placement_policy != 'cuda':
|
|
||||||
init_device = torch.device('cpu')
|
|
||||||
else:
|
|
||||||
init_device = None
|
|
||||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
|
||||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
|
||||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
|
||||||
|
|
||||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
|
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue