[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
pull/1785/head
HELSON 2022-11-02 16:11:34 +08:00 committed by GitHub
parent 32c1b843a9
commit c6a1a62636
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1041 additions and 951 deletions

View File

@ -1,11 +1,12 @@
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
class TensorState(Enum):
@ -58,6 +59,7 @@ class Chunk:
process_group: ColoProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False) -> None:
"""
@ -102,6 +104,11 @@ class Chunk:
self.cpu_shard = None
self.is_gathered = True
# configure the init deivce of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
self.shard_mem = self.chunk_mem // self.pg_size
@ -242,11 +249,8 @@ class Chunk:
self.tensors_state_monitor[tensor_state] += 1
self.utilized_size = new_utilized_size
def close_chunk(self, shard_dev: Optional[torch.device] = None):
def close_chunk(self):
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
"""
# sanity check
assert self.chunk_temp is not None
@ -265,21 +269,16 @@ class Chunk:
self.chunk_temp = None
self.__scatter()
# always gathered chunk does not have shard
if self.keep_gathered:
if shard_dev is None:
shard_dev = get_current_device()
else:
assert shard_dev.type == 'cuda'
elif shard_dev is None:
shard_dev = torch.device('cpu')
return
if self.pin_memory or shard_dev.type == 'cpu':
if self.pin_memory or self.shard_device.type == 'cpu':
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
self.cpu_shard.copy_(self.cuda_shard)
self.cpu_vis_flag = True # cpu_shard has been visited
if shard_dev.type == 'cpu':
if self.shard_device.type == 'cpu':
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):

View File

@ -1,10 +1,11 @@
import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
from colossalai.utils import get_current_device
import torch
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
from colossalai.utils import get_current_device
class ChunkManager:
@ -31,13 +32,19 @@ class ChunkManager:
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
def append_tensor(self,
tensor: ColoTensor,
group_type: str,
config_key: int,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""Append a tensor to a chunk.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world
cpu_offload: if True, the chunk will be closed on CPU
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
@ -67,6 +74,7 @@ class ChunkManager:
chunk_size=chunk_size,
process_group=tensor.process_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
**chunk_kwargs,
)
@ -206,9 +214,8 @@ class ChunkManager:
return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk):
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk(device)
chunk.close_chunk()
self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]):

View File

@ -1,9 +1,12 @@
import torch
import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .placement_policy import PlacementPolicyFactory
@ -25,6 +28,7 @@ class GeminiManager:
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None

View File

@ -1,19 +1,22 @@
import torch
import itertools
import torch.distributed as dist
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
from functools import partial
from typing import Dict, Iterable, List, Optional, Set
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from .reducer import Reducer
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
cpu_offload = self.gemini_manager.policy_name != 'cuda'
# TODO: get param order and filter unused params
for p in module.parameters():
assert isinstance(p, ColoParameter)
@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.append_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):

View File

@ -1,14 +1,15 @@
from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy
import torch
from functools import lru_cache
from typing import Callable, Optional, Set
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import ProcessGroup, ReplicaSpec
import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
@lru_cache(None)
@ -67,6 +68,7 @@ class ColoTensor(torch.Tensor):
data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""
@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func]
if cls.torch_minor >= 12:
# in order to trigger pre-op hook in the forward of checkpoint module
# we have to capture the `backward` function
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1 # only has 1 paramter
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions():

View File

@ -1,15 +1,17 @@
from enum import Enum
from typing import Dict, Set, Tuple
import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP
from typing import Dict, Tuple, Set
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device
class OptimState(Enum):
@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end

View File

@ -1,15 +1,17 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
import torch.distributed as dist
from functools import partial
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.gemini import TensorState
from colossalai.gemini.chunk import Chunk
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
def dist_sum(x):
@ -42,6 +44,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
process_group=pg,
dtype=torch.float32,
init_device=init_device,
cpu_shard_init=True,
keep_gathered=keep_gathered,
pin_memory=pin_memory)

View File

@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
@parameterize('keep_gather', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
@ -101,4 +102,4 @@ def test_gpt(world_size):
if __name__ == '__main__':
test_gpt(1)
test_gpt(4)

View File

@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
check_param(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
def exam_tiny_example(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
exam_tiny_example()
@pytest.mark.dist
@ -113,4 +158,4 @@ def test_gpt(world_size):
if __name__ == '__main__':
test_gpt(1)
test_gpt(2)