[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
HELSON 2022-11-02 16:11:34 +08:00 committed by GitHub
parent 32c1b843a9
commit c6a1a62636
9 changed files with 1041 additions and 951 deletions


@@ -1,11 +1,12 @@
-import torch
-import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, List
+from typing import Dict, List, Optional
+
+import torch
+import torch.distributed as dist
 
-from colossalai.utils import get_current_device
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.utils import get_current_device
 
 
 class TensorState(Enum):

@@ -58,6 +59,7 @@ class Chunk:
                  process_group: ColoProcessGroup,
                  dtype: torch.dtype,
                  init_device: Optional[torch.device] = None,
+                 cpu_shard_init: bool = False,
                  keep_gathered: bool = False,
                  pin_memory: bool = False) -> None:
         """

@@ -102,6 +104,11 @@ class Chunk:
         self.cpu_shard = None
         self.is_gathered = True
 
+        # configure the init deivce of the shard
+        # no-offload default: fp16, fp32 -> CUDA
+        # offload default: fp16, fp32 -> CPU
+        self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
+
         self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
         self.shard_mem = self.chunk_mem // self.pg_size

@@ -242,11 +249,8 @@ class Chunk:
             self.tensors_state_monitor[tensor_state] += 1
         self.utilized_size = new_utilized_size
 
-    def close_chunk(self, shard_dev: Optional[torch.device] = None):
+    def close_chunk(self):
         """Close the chunk. Any tensor can't be appended to a closed chunk later.
-
-        Args:
-            shard_dev: the device where the shard locates
         """
         # sanity check
         assert self.chunk_temp is not None

@@ -265,21 +269,16 @@ class Chunk:
         self.chunk_temp = None
 
         self.__scatter()
+        # always gathered chunk does not have shard
         if self.keep_gathered:
-            if shard_dev is None:
-                shard_dev = get_current_device()
-            else:
-                assert shard_dev.type == 'cuda'
-        elif shard_dev is None:
-            shard_dev = torch.device('cpu')
+            return
 
-        if self.pin_memory or shard_dev.type == 'cpu':
+        if self.pin_memory or self.shard_device.type == 'cpu':
             self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
             self.cpu_shard.copy_(self.cuda_shard)
             self.cpu_vis_flag = True    # cpu_shard has been visited
-        if shard_dev.type == 'cpu':
+        if self.shard_device.type == 'cpu':
             self.cuda_shard = None
 
     def shard_move(self, device: torch.device, force_copy: bool = False):


@@ -1,10 +1,11 @@
-import torch
-from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
+from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
 
-from colossalai.utils import get_current_device
+import torch
+
+from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
 from colossalai.tensor import ColoTensor
-from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
+from colossalai.utils import get_current_device
 
 
 class ChunkManager:

@@ -31,13 +32,19 @@ class ChunkManager:
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
-    def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
+    def append_tensor(self,
+                      tensor: ColoTensor,
+                      group_type: str,
+                      config_key: int,
+                      cpu_offload: bool = False,
+                      pin_memory: bool = False) -> None:
         """Append a tensor to a chunk.
 
         Args:
             tensor: the tensor appended to the chunk
             group_type: the data type of the group
             config_key: the key of the group's name, usually the size of the dp world
+            cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map

@@ -67,6 +74,7 @@ class ChunkManager:
                 chunk_size=chunk_size,
                 process_group=tensor.process_group,
                 dtype=tensor.dtype,
+                cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
                 **chunk_kwargs,
             )

@@ -206,9 +214,8 @@ class ChunkManager:
         return self.chunk_groups[group_name]
 
     def __close_one_chunk(self, chunk: Chunk):
-        device = get_current_device() if chunk.keep_gathered else self.device    # keep gathered chunk in cuda
         self.__sub_memroy_usage(chunk.memory_usage)
-        chunk.close_chunk(device)
+        chunk.close_chunk()
         self.__add_memory_usage(chunk.memory_usage)
 
     def __sub_memroy_usage(self, usage: Dict[str, int]):


@@ -1,9 +1,12 @@
-import torch
 import functools
-from .memory_tracer.memstats_collector import MemStatsCollectorV2
-from typing import List, Optional, Tuple
 from time import time
+from typing import List, Optional, Tuple
+
+import torch
 
 from colossalai.gemini.chunk import Chunk, ChunkManager
+
+from .memory_tracer.memstats_collector import MemStatsCollectorV2
 from .placement_policy import PlacementPolicyFactory

@@ -25,6 +28,7 @@ class GeminiManager:
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
         assert placement_policy in PlacementPolicyFactory.get_polocy_names()
+        self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
         self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None


@@ -1,19 +1,22 @@
-import torch
 import itertools
-import torch.distributed as dist
-from functools import partial
-from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor.param_op_hook import ParamOpHookManager
-from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable, List, Optional, Set
-from colossalai.logging import get_dist_logger
 from collections import OrderedDict
-from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from .reducer import Reducer
-from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
+from functools import partial
+from typing import Dict, Iterable, List, Optional, Set
+
+import torch
+import torch.distributed as dist
+
+from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
+from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.utils import get_current_device
+from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+
+from .reducer import Reducer
 
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys

@@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
 
+        cpu_offload = self.gemini_manager.policy_name != 'cuda'
         # TODO: get param order and filter unused params
         for p in module.parameters():
             assert isinstance(p, ColoParameter)

@@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             p.data = p.data.half()
             dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
-            self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
+            self.chunk_manager.append_tensor(tensor=p,
+                                             group_type='fp16_param',
+                                             config_key=dp_world_size,
+                                             cpu_offload=cpu_offload,
+                                             pin_memory=pin_memory)
+            self.chunk_manager.append_tensor(tensor=fp32_p,
+                                             group_type='fp32_param',
+                                             config_key=dp_world_size,
+                                             cpu_offload=cpu_offload,
+                                             pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()

@@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
             chunk_32.init_pair(chunk_16)
 
+            # keep gathered chunks are in CUDA
+            if chunk_16.keep_gathered:
+                self.grads_device[p] = get_current_device()
+
         self._logger = get_dist_logger()
 
     def forward(self, *args, **kwargs):
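Two consequences of these ZeroDDP changes are easy to miss when reading the hunks in isolation: chunks are closed onto the CPU for every placement policy except 'cuda', and the gradients of keep-gathered chunks always stay on the current CUDA device. A toy restatement (illustrative Python only; in the library the fallback is gemini_manager.default_device, whatever the chosen policy sets it to):

def zero_ddp_placement(policy_name: str, keep_gathered: bool, default_device: str) -> dict:
    # every policy except pure 'cuda' asks the chunk manager to close chunks on the CPU
    cpu_offload = policy_name != 'cuda'
    # gradients follow the Gemini manager's default device, unless the chunk is
    # kept gathered, in which case they are pinned to the current CUDA device
    grads_device = 'cuda' if keep_gathered else default_device
    return {'cpu_offload': cpu_offload, 'grads_device': grads_device}


# policy names as used by the tests in this commit
for policy in ('cuda', 'cpu', 'auto', 'const'):
    print(policy, zero_ddp_placement(policy, keep_gathered=False, default_device='<default>'))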


@@ -1,14 +1,15 @@
-from .op_wrapper import _COLOSSAL_OPS
-from .const import TensorType
 from copy import copy
-import torch
 from functools import lru_cache
+from typing import Callable, Optional, Set
 
-from colossalai.tensor import ColoTensorSpec
-from colossalai.tensor import ProcessGroup, ReplicaSpec
+import torch
+
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-from typing import Optional, Set, Callable
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
+
+from .const import TensorType
+from .op_wrapper import _COLOSSAL_OPS
 
 
 @lru_cache(None)

@@ -67,6 +68,7 @@ class ColoTensor(torch.Tensor):
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
+    torch_minor = int(torch.__version__.split('.')[1])
 
     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
         """

@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
         if func in _COLOSSAL_OPS:
             func = _COLOSSAL_OPS[func]
 
+        if cls.torch_minor >= 12:
+            # in order to trigger pre-op hook in the forward of checkpoint module
+            # we have to capture the `backward` function
+            # and make sure that it does not in `torch._C.DisableTorchFunction()` context
+            if func is torch.Tensor.backward:
+                assert len(args) == 1    # only has 1 paramter
+                backward_tensor = torch.Tensor(args[0])
+                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
+                return backward_tensor.backward(**tensor_kwargs)
+
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
             if func in _get_my_nowrap_functions():
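This torch >= 1.12 branch is the heart of the hotfix: torch.utils.checkpoint re-runs the forward inside the backward pass, and if Tensor.backward were dispatched from within DisableTorchFunction(), ColoTensor's pre-op hooks (and therefore ZeRO's chunk gathering) would never fire during that recomputation. Below is a rough end-to-end sketch of the scenario the fix targets, on a single GPU; the ColoInitContext and ZeroOptimizer import paths and the optimizer-driven backward call are assumptions based on the test files touched by this commit and may differ in your version:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

import colossalai
from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext    # import path assumed
from colossalai.zero import ZeroOptimizer    # import path assumed


class CheckpointedNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        # the recomputation of fc1 during backward must still trigger the
        # ColoTensor pre-op hooks so that ZeRO gathers the parameter chunk
        h = checkpoint(self.fc1, x)
        return self.fc2(h).sum()


def main():
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    with ColoInitContext(device=get_current_device()):
        model = CheckpointedNet()

    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
    model = ZeroDDP(model, GeminiManager('cuda', chunk_manager), pin_memory=True)
    optim = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=32)

    # fp16 input with requires_grad so the reentrant checkpoint records the graph
    data = torch.randn(4, 32, dtype=torch.half, device=get_current_device(), requires_grad=True)
    loss = model(data)
    optim.backward(loss)    # exercises the torch.Tensor.backward branch above under torch>=1.12
    optim.step()


if __name__ == '__main__':
    main()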


@@ -1,15 +1,17 @@
+from enum import Enum
+from typing import Dict, Set, Tuple
+
 import torch
 import torch.distributed as dist
-from enum import Enum
-from torch.optim import Optimizer
 from torch.nn import Parameter
-from colossalai.nn.parallel.data_parallel import ZeroDDP
-from typing import Dict, Tuple, Set
+from torch.optim import Optimizer
+
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.gemini.chunk import Chunk, ChunkManager
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils import get_current_device, disposable
-from colossalai.gemini.chunk import Chunk, ChunkManager
+from colossalai.nn.parallel.data_parallel import ZeroDDP
+from colossalai.utils import disposable, get_current_device
 
 
 class OptimState(Enum):

@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
         def get_range_pair(local_chunk: Chunk, local_param: Parameter):
             param_info = local_chunk.tensors_info[local_param]
+            if local_chunk.keep_gathered:
+                return param_info.offset, param_info.end
             begin = max(0, param_info.offset - local_chunk.shard_begin)
             end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
             return begin, end
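A worked illustration of the new branch with hypothetical offsets: a keep-gathered chunk holds the whole flattened chunk on every rank, so the optimizer can take the parameter's own slice directly, while a sharded chunk must clamp that slice to the part owned by the local shard.

# a parameter occupying elements [100, 260) of the flattened chunk, on a rank
# whose local shard covers elements [128, 256)
shard_begin, shard_size = 128, 128
offset, end = 100, 260

# keep_gathered chunk: every rank sees the whole chunk
print((offset, end))                         # (100, 260)

# sharded chunk: clamp to the locally owned range
begin = max(0, offset - shard_begin)         # 0, elements 100..127 live on the previous rank
stop = min(shard_size, end - shard_begin)    # 128, elements 256..259 live on the next rank
print((begin, stop))                         # (0, 128)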


@@ -1,15 +1,17 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
-import torch.distributed as dist
 from functools import partial
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
-from colossalai.utils import free_port, get_current_device
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from colossalai.tensor import ColoParameter
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import colossalai
 from colossalai.gemini import TensorState
 from colossalai.gemini.chunk import Chunk
+from colossalai.tensor import ColoParameter
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
 
 
 def dist_sum(x):

@@ -42,6 +44,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
                      process_group=pg,
                      dtype=torch.float32,
                      init_device=init_device,
+                     cpu_shard_init=True,
                      keep_gathered=keep_gathered,
                      pin_memory=pin_memory)


@@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
-def exam_gpt_fwd_bwd(placement_policy):
+@parameterize('keep_gather', [False, True])
+def exam_gpt_fwd_bwd(placement_policy, keep_gather):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
     world_size = torch.distributed.get_world_size()
     config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     config_dict[world_size]['chunk_size'] = 5000
-    config_dict[world_size]['keep_gathered'] = False
+    config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager, pin_memory=True)

@@ -101,4 +102,4 @@ def test_gpt(world_size):
 
 
 if __name__ == '__main__':
-    test_gpt(1)
+    test_gpt(4)


@@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.amp import convert_to_apex_amp
-from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.nn.parallel import ZeroDDP

@@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
         check_param(model, torch_model)
 
 
+@parameterize('placement_policy', ['cuda', 'cpu'])
+def exam_tiny_example(placement_policy):
+    set_seed(42)
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+
+    torch_model = model_builder().cuda()
+    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+        torch_p.data.copy_(p.data)
+
+    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
+
+    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
+    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
+    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
+
+    model.eval()
+    torch_model.eval()
+
+    set_seed(dist.get_rank() * 3 + 128)
+    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+        if i > 2:
+            break
+
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
+        # debug_print([0], zero_logits, torch_logits)
+
+        zero_optim.step()
+        torch_optim.step()
+        check_param(model, torch_model)
+
+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_gpt_fwd_bwd()
+    exam_tiny_example()
 
 
 @pytest.mark.dist

@@ -113,4 +158,4 @@ def test_gpt(world_size):
 
 
 if __name__ == '__main__':
-    test_gpt(1)
+    test_gpt(2)