mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)
* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12 * [zero] add cpu shard init * [zero] add tiny example test * [colo_tensor] fix bugs for torch-1.11pull/1785/head
parent
32c1b843a9
commit
c6a1a62636
|
@ -1,11 +1,12 @@
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
class TensorState(Enum):
|
class TensorState(Enum):
|
||||||
|
@ -58,6 +59,7 @@ class Chunk:
|
||||||
process_group: ColoProcessGroup,
|
process_group: ColoProcessGroup,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
init_device: Optional[torch.device] = None,
|
init_device: Optional[torch.device] = None,
|
||||||
|
cpu_shard_init: bool = False,
|
||||||
keep_gathered: bool = False,
|
keep_gathered: bool = False,
|
||||||
pin_memory: bool = False) -> None:
|
pin_memory: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -102,6 +104,11 @@ class Chunk:
|
||||||
self.cpu_shard = None
|
self.cpu_shard = None
|
||||||
self.is_gathered = True
|
self.is_gathered = True
|
||||||
|
|
||||||
|
# configure the init deivce of the shard
|
||||||
|
# no-offload default: fp16, fp32 -> CUDA
|
||||||
|
# offload default: fp16, fp32 -> CPU
|
||||||
|
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
|
||||||
|
|
||||||
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
|
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
|
||||||
self.shard_mem = self.chunk_mem // self.pg_size
|
self.shard_mem = self.chunk_mem // self.pg_size
|
||||||
|
|
||||||
|
@ -242,11 +249,8 @@ class Chunk:
|
||||||
self.tensors_state_monitor[tensor_state] += 1
|
self.tensors_state_monitor[tensor_state] += 1
|
||||||
self.utilized_size = new_utilized_size
|
self.utilized_size = new_utilized_size
|
||||||
|
|
||||||
def close_chunk(self, shard_dev: Optional[torch.device] = None):
|
def close_chunk(self):
|
||||||
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
|
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
|
||||||
|
|
||||||
Args:
|
|
||||||
shard_dev: the device where the shard locates
|
|
||||||
"""
|
"""
|
||||||
# sanity check
|
# sanity check
|
||||||
assert self.chunk_temp is not None
|
assert self.chunk_temp is not None
|
||||||
|
@ -265,21 +269,16 @@ class Chunk:
|
||||||
self.chunk_temp = None
|
self.chunk_temp = None
|
||||||
|
|
||||||
self.__scatter()
|
self.__scatter()
|
||||||
|
# always gathered chunk does not have shard
|
||||||
if self.keep_gathered:
|
if self.keep_gathered:
|
||||||
if shard_dev is None:
|
return
|
||||||
shard_dev = get_current_device()
|
|
||||||
else:
|
|
||||||
assert shard_dev.type == 'cuda'
|
|
||||||
elif shard_dev is None:
|
|
||||||
shard_dev = torch.device('cpu')
|
|
||||||
|
|
||||||
if self.pin_memory or shard_dev.type == 'cpu':
|
if self.pin_memory or self.shard_device.type == 'cpu':
|
||||||
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
|
self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
|
||||||
self.cpu_shard.copy_(self.cuda_shard)
|
self.cpu_shard.copy_(self.cuda_shard)
|
||||||
self.cpu_vis_flag = True # cpu_shard has been visited
|
self.cpu_vis_flag = True # cpu_shard has been visited
|
||||||
|
|
||||||
if shard_dev.type == 'cpu':
|
if self.shard_device.type == 'cpu':
|
||||||
self.cuda_shard = None
|
self.cuda_shard = None
|
||||||
|
|
||||||
def shard_move(self, device: torch.device, force_copy: bool = False):
|
def shard_move(self, device: torch.device, force_copy: bool = False):
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import torch
|
|
||||||
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from colossalai.utils import get_current_device
|
import torch
|
||||||
|
|
||||||
|
from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
|
||||||
from colossalai.tensor import ColoTensor
|
from colossalai.tensor import ColoTensor
|
||||||
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
class ChunkManager:
|
class ChunkManager:
|
||||||
|
@ -31,13 +32,19 @@ class ChunkManager:
|
||||||
self.accessed_mem: int = 0
|
self.accessed_mem: int = 0
|
||||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||||
|
|
||||||
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
|
def append_tensor(self,
|
||||||
|
tensor: ColoTensor,
|
||||||
|
group_type: str,
|
||||||
|
config_key: int,
|
||||||
|
cpu_offload: bool = False,
|
||||||
|
pin_memory: bool = False) -> None:
|
||||||
"""Append a tensor to a chunk.
|
"""Append a tensor to a chunk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: the tensor appended to the chunk
|
tensor: the tensor appended to the chunk
|
||||||
group_type: the data type of the group
|
group_type: the data type of the group
|
||||||
config_key: the key of the group's name, usually the size of the dp world
|
config_key: the key of the group's name, usually the size of the dp world
|
||||||
|
cpu_offload: if True, the chunk will be closed on CPU
|
||||||
pin_memory: whether the chunk is pinned in the cpu memory
|
pin_memory: whether the chunk is pinned in the cpu memory
|
||||||
"""
|
"""
|
||||||
assert tensor not in self.tensor_chunk_map
|
assert tensor not in self.tensor_chunk_map
|
||||||
|
@ -67,6 +74,7 @@ class ChunkManager:
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
process_group=tensor.process_group,
|
process_group=tensor.process_group,
|
||||||
dtype=tensor.dtype,
|
dtype=tensor.dtype,
|
||||||
|
cpu_shard_init=cpu_offload,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
**chunk_kwargs,
|
**chunk_kwargs,
|
||||||
)
|
)
|
||||||
|
@ -206,9 +214,8 @@ class ChunkManager:
|
||||||
return self.chunk_groups[group_name]
|
return self.chunk_groups[group_name]
|
||||||
|
|
||||||
def __close_one_chunk(self, chunk: Chunk):
|
def __close_one_chunk(self, chunk: Chunk):
|
||||||
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
|
|
||||||
self.__sub_memroy_usage(chunk.memory_usage)
|
self.__sub_memroy_usage(chunk.memory_usage)
|
||||||
chunk.close_chunk(device)
|
chunk.close_chunk()
|
||||||
self.__add_memory_usage(chunk.memory_usage)
|
self.__add_memory_usage(chunk.memory_usage)
|
||||||
|
|
||||||
def __sub_memroy_usage(self, usage: Dict[str, int]):
|
def __sub_memroy_usage(self, usage: Dict[str, int]):
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
import torch
|
|
||||||
import functools
|
import functools
|
||||||
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from time import time
|
from time import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||||
|
|
||||||
|
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||||
from .placement_policy import PlacementPolicyFactory
|
from .placement_policy import PlacementPolicyFactory
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,6 +28,7 @@ class GeminiManager:
|
||||||
|
|
||||||
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
|
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
|
||||||
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
|
||||||
|
self.policy_name = placement_policy
|
||||||
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
policy_cls = PlacementPolicyFactory.create(placement_policy)
|
||||||
self._chunk_manager = chunk_manager
|
self._chunk_manager = chunk_manager
|
||||||
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
|
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
|
||||||
|
|
|
@ -1,19 +1,22 @@
|
||||||
import torch
|
|
||||||
import itertools
|
import itertools
|
||||||
import torch.distributed as dist
|
|
||||||
from functools import partial
|
|
||||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
|
||||||
from typing import Dict, Iterable, List, Optional, Set
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
from functools import partial
|
||||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
from typing import Dict, Iterable, List, Optional, Set
|
||||||
from .reducer import Reducer
|
|
||||||
|
|
||||||
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
|
||||||
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||||
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||||
|
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||||
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||||
|
|
||||||
|
from .reducer import Reducer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||||
|
@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
|
||||||
self.overflow_counter = 0
|
self.overflow_counter = 0
|
||||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||||
|
|
||||||
|
cpu_offload = self.gemini_manager.policy_name != 'cuda'
|
||||||
# TODO: get param order and filter unused params
|
# TODO: get param order and filter unused params
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
assert isinstance(p, ColoParameter)
|
assert isinstance(p, ColoParameter)
|
||||||
|
@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
|
||||||
fp32_data = p.data.float()
|
fp32_data = p.data.float()
|
||||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||||
p.data = p.data.half()
|
p.data = p.data.half()
|
||||||
|
|
||||||
dp_world_size = p.process_group.dp_world_size()
|
dp_world_size = p.process_group.dp_world_size()
|
||||||
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
|
self.chunk_manager.append_tensor(tensor=p,
|
||||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
|
group_type='fp16_param',
|
||||||
|
config_key=dp_world_size,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
pin_memory=pin_memory)
|
||||||
|
self.chunk_manager.append_tensor(tensor=fp32_p,
|
||||||
|
group_type='fp32_param',
|
||||||
|
config_key=dp_world_size,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
|
pin_memory=pin_memory)
|
||||||
self.fp32_params.append(fp32_p)
|
self.fp32_params.append(fp32_p)
|
||||||
self.grads_device[p] = self.gemini_manager.default_device
|
self.grads_device[p] = self.gemini_manager.default_device
|
||||||
self.chunk_manager.close_all_groups()
|
self.chunk_manager.close_all_groups()
|
||||||
|
@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
|
||||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||||
chunk_32.init_pair(chunk_16)
|
chunk_32.init_pair(chunk_16)
|
||||||
|
|
||||||
|
# keep gathered chunks are in CUDA
|
||||||
|
if chunk_16.keep_gathered:
|
||||||
|
self.grads_device[p] = get_current_device()
|
||||||
|
|
||||||
self._logger = get_dist_logger()
|
self._logger = get_dist_logger()
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
from .op_wrapper import _COLOSSAL_OPS
|
|
||||||
from .const import TensorType
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
import torch
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from typing import Callable, Optional, Set
|
||||||
|
|
||||||
from colossalai.tensor import ColoTensorSpec
|
import torch
|
||||||
from colossalai.tensor import ProcessGroup, ReplicaSpec
|
|
||||||
|
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
|
||||||
from colossalai.tensor.dist_spec_mgr import DistSpecManager
|
from colossalai.tensor.dist_spec_mgr import DistSpecManager
|
||||||
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
|
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
|
||||||
from typing import Optional, Set, Callable
|
|
||||||
|
from .const import TensorType
|
||||||
|
from .op_wrapper import _COLOSSAL_OPS
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(None)
|
@lru_cache(None)
|
||||||
|
@ -67,6 +68,7 @@ class ColoTensor(torch.Tensor):
|
||||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
||||||
"""
|
"""
|
||||||
|
torch_minor = int(torch.__version__.split('.')[1])
|
||||||
|
|
||||||
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
|
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
|
||||||
"""
|
"""
|
||||||
|
@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
|
||||||
if func in _COLOSSAL_OPS:
|
if func in _COLOSSAL_OPS:
|
||||||
func = _COLOSSAL_OPS[func]
|
func = _COLOSSAL_OPS[func]
|
||||||
|
|
||||||
|
if cls.torch_minor >= 12:
|
||||||
|
# in order to trigger pre-op hook in the forward of checkpoint module
|
||||||
|
# we have to capture the `backward` function
|
||||||
|
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
|
||||||
|
if func is torch.Tensor.backward:
|
||||||
|
assert len(args) == 1 # only has 1 paramter
|
||||||
|
backward_tensor = torch.Tensor(args[0])
|
||||||
|
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
|
||||||
|
return backward_tensor.backward(**tensor_kwargs)
|
||||||
|
|
||||||
with torch._C.DisableTorchFunction():
|
with torch._C.DisableTorchFunction():
|
||||||
ret = func(*args, **kwargs)
|
ret = func(*args, **kwargs)
|
||||||
if func in _get_my_nowrap_functions():
|
if func in _get_my_nowrap_functions():
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from enum import Enum
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
from torch.optim import Optimizer
|
||||||
from typing import Dict, Tuple, Set
|
|
||||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||||
|
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||||
from colossalai.utils import get_current_device, disposable
|
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
from colossalai.utils import disposable, get_current_device
|
||||||
|
|
||||||
|
|
||||||
class OptimState(Enum):
|
class OptimState(Enum):
|
||||||
|
@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||||
|
|
||||||
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
|
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
|
||||||
param_info = local_chunk.tensors_info[local_param]
|
param_info = local_chunk.tensors_info[local_param]
|
||||||
|
if local_chunk.keep_gathered:
|
||||||
|
return param_info.offset, param_info.end
|
||||||
begin = max(0, param_info.offset - local_chunk.shard_begin)
|
begin = max(0, param_info.offset - local_chunk.shard_begin)
|
||||||
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
|
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
|
||||||
return begin, end
|
return begin, end
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
import torch
|
|
||||||
import colossalai
|
|
||||||
import pytest
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch.distributed as dist
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
|
||||||
from colossalai.utils import free_port, get_current_device
|
import pytest
|
||||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
import torch
|
||||||
from colossalai.tensor import ColoParameter
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
from colossalai.gemini import TensorState
|
from colossalai.gemini import TensorState
|
||||||
from colossalai.gemini.chunk import Chunk
|
from colossalai.gemini.chunk import Chunk
|
||||||
|
from colossalai.tensor import ColoParameter
|
||||||
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port, get_current_device
|
||||||
|
|
||||||
|
|
||||||
def dist_sum(x):
|
def dist_sum(x):
|
||||||
|
@ -42,6 +44,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
|
||||||
process_group=pg,
|
process_group=pg,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
init_device=init_device,
|
init_device=init_device,
|
||||||
|
cpu_shard_init=True,
|
||||||
keep_gathered=keep_gathered,
|
keep_gathered=keep_gathered,
|
||||||
pin_memory=pin_memory)
|
pin_memory=pin_memory)
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||||
|
|
||||||
|
|
||||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
|
||||||
def exam_gpt_fwd_bwd(placement_policy):
|
@parameterize('keep_gather', [False, True])
|
||||||
|
def exam_gpt_fwd_bwd(placement_policy, keep_gather):
|
||||||
set_seed(42)
|
set_seed(42)
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||||
config_dict[world_size]['chunk_size'] = 5000
|
config_dict[world_size]['chunk_size'] = 5000
|
||||||
config_dict[world_size]['keep_gathered'] = False
|
config_dict[world_size]['keep_gathered'] = keep_gather
|
||||||
chunk_manager = ChunkManager(config_dict)
|
chunk_manager = ChunkManager(config_dict)
|
||||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||||
|
@ -101,4 +102,4 @@ def test_gpt(world_size):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_gpt(1)
|
test_gpt(4)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.amp import convert_to_apex_amp
|
from colossalai.amp import convert_to_apex_amp
|
||||||
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.nn.parallel import ZeroDDP
|
from colossalai.nn.parallel import ZeroDDP
|
||||||
|
@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
|
||||||
check_param(model, torch_model)
|
check_param(model, torch_model)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||||
|
def exam_tiny_example(placement_policy):
|
||||||
|
set_seed(42)
|
||||||
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
|
with ColoInitContext(device=get_current_device()):
|
||||||
|
model = model_builder()
|
||||||
|
|
||||||
|
torch_model = model_builder().cuda()
|
||||||
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||||
|
torch_p.data.copy_(p.data)
|
||||||
|
|
||||||
|
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
|
||||||
|
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
|
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||||
|
|
||||||
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
|
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
|
||||||
|
|
||||||
|
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||||
|
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||||
|
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||||
|
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
torch_model.eval()
|
||||||
|
|
||||||
|
set_seed(dist.get_rank() * 3 + 128)
|
||||||
|
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||||
|
if i > 2:
|
||||||
|
break
|
||||||
|
|
||||||
|
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
|
||||||
|
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||||
|
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
|
||||||
|
# debug_print([0], zero_logits, torch_logits)
|
||||||
|
|
||||||
|
zero_optim.step()
|
||||||
|
torch_optim.step()
|
||||||
|
|
||||||
|
check_param(model, torch_model)
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port):
|
||||||
config = {}
|
config = {}
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
exam_gpt_fwd_bwd()
|
exam_gpt_fwd_bwd()
|
||||||
|
exam_tiny_example()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@ -113,4 +158,4 @@ def test_gpt(world_size):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_gpt(1)
|
test_gpt(2)
|
||||||
|
|
Loading…
Reference in New Issue