[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)

* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12

* [zero] add cpu shard init

* [zero] add tiny example test

* [colo_tensor] fix bugs for torch-1.11
pull/1785/head
HELSON 2022-11-02 16:11:34 +08:00 committed by GitHub
parent 32c1b843a9
commit c6a1a62636
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1041 additions and 951 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,230 +1,237 @@
import torch from collections import deque
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
from collections import deque
import torch
from colossalai.utils import get_current_device
from colossalai.tensor import ColoTensor from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device
class ChunkManager:
""" class ChunkManager:
A manager class to manipulate the tensors in chunks. """
A manager class to manipulate the tensors in chunks.
Args:
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. Args:
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
""" init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.size_config: Dict[int, int] = dict() self.device = init_device or get_current_device()
self.kwargs_config = chunk_configuration self.size_config: Dict[int, int] = dict()
for k, v in self.kwargs_config.items(): self.kwargs_config = chunk_configuration
self.size_config[k] = v.pop('chunk_size') for k, v in self.kwargs_config.items():
v['init_device'] = self.device self.size_config[k] = v.pop('chunk_size')
v['init_device'] = self.device
self.chunk_groups: Dict[str, Deque] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() self.chunk_groups: Dict[str, Deque] = dict()
self.accessed_chunks: Set[Chunk] = set() self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_mem: int = 0 self.accessed_chunks: Set[Chunk] = set()
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
"""Append a tensor to a chunk. def append_tensor(self,
tensor: ColoTensor,
Args: group_type: str,
tensor: the tensor appended to the chunk config_key: int,
group_type: the data type of the group cpu_offload: bool = False,
config_key: the key of the group's name, usually the size of the dp world pin_memory: bool = False) -> None:
pin_memory: whether the chunk is pinned in the cpu memory """Append a tensor to a chunk.
"""
assert tensor not in self.tensor_chunk_map Args:
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" tensor: the tensor appended to the chunk
assert config_key in self.size_config group_type: the data type of the group
config_key: the key of the group's name, usually the size of the dp world
chunk_size = self.size_config[config_key] cpu_offload: if True, the chunk will be closed on CPU
chunk_kwargs = self.kwargs_config[config_key] pin_memory: whether the chunk is pinned in the cpu memory
group_name = "{}_{}".format(group_type, config_key) """
chunk_group = self.__get_chunk_group(group_name) assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
try: assert config_key in self.size_config
# append the tensor to the last chunk
chunk_group[-1].append_tensor(tensor) chunk_size = self.size_config[config_key]
except (IndexError, ChunkFullError): chunk_kwargs = self.kwargs_config[config_key]
# the except statement will be triggered when there is no chunk or group_name = "{}_{}".format(group_type, config_key)
# the last chunk in the chunk group is full chunk_group = self.__get_chunk_group(group_name)
# this will create a new chunk and allocate this chunk to its corresponding process
if chunk_group: try:
# the chunk group is not empty # append the tensor to the last chunk
# close the last chunk chunk_group[-1].append_tensor(tensor)
self.__close_one_chunk(chunk_group[-1]) except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
if tensor.numel() > chunk_size: # the last chunk in the chunk group is full
chunk_size = tensor.numel() # this will create a new chunk and allocate this chunk to its corresponding process
chunk = Chunk( if chunk_group:
chunk_size=chunk_size, # the chunk group is not empty
process_group=tensor.process_group, # close the last chunk
dtype=tensor.dtype, self.__close_one_chunk(chunk_group[-1])
pin_memory=pin_memory,
**chunk_kwargs, if tensor.numel() > chunk_size:
) chunk_size = tensor.numel()
chunk = Chunk(
chunk_group.append(chunk) chunk_size=chunk_size,
chunk.append_tensor(tensor) process_group=tensor.process_group,
self.__add_memory_usage(chunk.memory_usage) dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
self.tensor_chunk_map[tensor] = chunk_group[-1] pin_memory=pin_memory,
**chunk_kwargs,
def close_all_groups(self): )
"""Close all the chunks of all groups.
""" chunk_group.append(chunk)
for group_name in self.chunk_groups: chunk.append_tensor(tensor)
self.__close_one_chunk(self.chunk_groups[group_name][-1]) self.__add_memory_usage(chunk.memory_usage)
def access_chunk(self, chunk: Chunk) -> None: self.tensor_chunk_map[tensor] = chunk_group[-1]
"""Make the chunk can be used for calculation.
""" def close_all_groups(self):
if chunk in self.accessed_chunks: """Close all the chunks of all groups.
return """
self.__sub_memroy_usage(chunk.memory_usage) for group_name in self.chunk_groups:
if chunk.device_type == 'cpu': self.__close_one_chunk(self.chunk_groups[group_name][-1])
chunk.shard_move(get_current_device())
self.__add_accessed_chunk(chunk) def access_chunk(self, chunk: Chunk) -> None:
self.__add_memory_usage(chunk.memory_usage) """Make the chunk can be used for calculation.
"""
def release_chunk(self, chunk: Chunk) -> None: if chunk in self.accessed_chunks:
"""Scatter the chunk in CUDA. return
""" self.__sub_memroy_usage(chunk.memory_usage)
if chunk not in self.accessed_chunks: if chunk.device_type == 'cpu':
return chunk.shard_move(get_current_device())
if chunk.can_release: self.__add_accessed_chunk(chunk)
self.__sub_memroy_usage(chunk.memory_usage) self.__add_memory_usage(chunk.memory_usage)
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage) def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: """
"""Move the shard of the chunk to the target device. if chunk not in self.accessed_chunks:
""" return
if not chunk.can_move or chunk.device_type == device.type: if chunk.can_release:
return self.__sub_memroy_usage(chunk.memory_usage)
self.__sub_memroy_usage(chunk.memory_usage) self.__sub_accessed_chunk(chunk)
chunk.shard_move(device, force_copy) self.__add_memory_usage(chunk.memory_usage)
self.__add_memory_usage(chunk.memory_usage)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: """Move the shard of the chunk to the target device.
"""Transit tensor state according to pre-defined state machine. """
""" if not chunk.can_move or chunk.device_type == device.type:
chunk = self.tensor_chunk_map[tensor] return
chunk.tensor_trans_state(tensor, state) self.__sub_memroy_usage(chunk.memory_usage)
chunk.shard_move(device, force_copy)
def reduce_chunk(self, chunk: Chunk) -> bool: self.__add_memory_usage(chunk.memory_usage)
"""Reduce or all reduce the chunk.
""" def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
if not chunk.can_reduce: """Transit tensor state according to pre-defined state machine.
return False """
self.__sub_memroy_usage(chunk.memory_usage) chunk = self.tensor_chunk_map[tensor]
chunk.reduce() chunk.tensor_trans_state(tensor, state)
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage) def reduce_chunk(self, chunk: Chunk) -> bool:
return True """Reduce or all reduce the chunk.
"""
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: if not chunk.can_reduce:
""" return False
Copy data to the chunk. self.__sub_memroy_usage(chunk.memory_usage)
chunk.reduce()
Args: self.__sub_accessed_chunk(chunk)
tensor (torch.Tensor): the tensor used to retrive meta information self.__add_memory_usage(chunk.memory_usage)
data (torch.Tensor): the tensor to be copied to the chunk return True
"""
chunk = self.tensor_chunk_map[tensor] def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
chunk.copy_tensor_to_chunk_slice(tensor, data) """
Copy data to the chunk.
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
""" Args:
Return the chunk owning the tensor. tensor (torch.Tensor): the tensor used to retrive meta information
data (torch.Tensor): the tensor to be copied to the chunk
Args: """
tensor (torch.Tensor): a torch tensor object chunk = self.tensor_chunk_map[tensor]
""" chunk.copy_tensor_to_chunk_slice(tensor, data)
return self.tensor_chunk_map[tensor]
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
def get_cuda_movable_chunks(self) -> List[Chunk]: """
""" Return the chunk owning the tensor.
Get all chunks that can be moved.
""" Args:
chunk_list = [] tensor (torch.Tensor): a torch tensor object
for chunk in self.accessed_chunks: """
if chunk.can_release: return self.tensor_chunk_map[tensor]
chunk_list.append(chunk)
chunk_list.sort(key=lambda x: x.count_id) def get_cuda_movable_chunks(self) -> List[Chunk]:
return chunk_list """
Get all chunks that can be moved.
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: """
""" chunk_list = []
Get all chunks owning the input tensors. for chunk in self.accessed_chunks:
if chunk.can_release:
Args: chunk_list.append(chunk)
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks chunk_list.sort(key=lambda x: x.count_id)
""" return chunk_list
chunks = []
for tensor in tensors: def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
chunk = self.get_chunk(tensor) """
if chunk not in chunks: Get all chunks owning the input tensors.
chunks.append(chunk)
return tuple(chunks) Args:
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: """
"""Add extern static tensor to chunk manager. chunks = []
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. for tensor in tensors:
They are "static", which means their shape, dtype, device never change. chunk = self.get_chunk(tensor)
Thus, their memory usage never changes. if chunk not in chunks:
chunks.append(chunk)
Args: return tuple(chunks)
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
""" def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
assert tensor not in self.tensor_chunk_map """Add extern static tensor to chunk manager.
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
They are "static", which means their shape, dtype, device never change.
def __repr__(self) -> str: Thus, their memory usage never changes.
msg = [
'Chunk Manager Information:\n', Args:
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
] """
for group_name, group in self.chunk_groups.items(): assert tensor not in self.tensor_chunk_map
msg.append(f'Group {group_name}:\n') self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n') def __repr__(self) -> str:
return ''.join(msg) msg = [
'Chunk Manager Information:\n',
def __get_chunk_group(self, group_name: str) -> Deque: 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
"""Register a chunk group. ]
""" for group_name, group in self.chunk_groups.items():
if group_name not in self.chunk_groups: msg.append(f'Group {group_name}:\n')
self.chunk_groups[group_name] = deque() for i, chunk in enumerate(group):
return self.chunk_groups[group_name] msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def __close_one_chunk(self, chunk: Chunk):
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda def __get_chunk_group(self, group_name: str) -> Deque:
self.__sub_memroy_usage(chunk.memory_usage) """Register a chunk group.
chunk.close_chunk(device) """
self.__add_memory_usage(chunk.memory_usage) if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()
def __sub_memroy_usage(self, usage: Dict[str, int]): return self.chunk_groups[group_name]
for k, v in usage.items():
self.total_mem[k] -= v def __close_one_chunk(self, chunk: Chunk):
self.__sub_memroy_usage(chunk.memory_usage)
def __add_memory_usage(self, usage: Dict[str, int]): chunk.close_chunk()
for k, v in usage.items(): self.__add_memory_usage(chunk.memory_usage)
self.total_mem[k] += v
def __sub_memroy_usage(self, usage: Dict[str, int]):
def __add_accessed_chunk(self, chunk: Chunk): for k, v in usage.items():
chunk.access_chunk() self.total_mem[k] -= v
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem def __add_memory_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
def __sub_accessed_chunk(self, chunk: Chunk): self.total_mem[k] += v
chunk.release_chunk()
self.accessed_chunks.remove(chunk) def __add_accessed_chunk(self, chunk: Chunk):
self.accessed_mem -= chunk.chunk_mem chunk.access_chunk()
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem

View File

@ -1,9 +1,12 @@
import torch
import functools import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time from time import time
from typing import List, Optional, Tuple
import torch
from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.gemini.chunk import Chunk, ChunkManager
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .placement_policy import PlacementPolicyFactory from .placement_policy import PlacementPolicyFactory
@ -25,6 +28,7 @@ class GeminiManager:
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
assert placement_policy in PlacementPolicyFactory.get_polocy_names() assert placement_policy in PlacementPolicyFactory.get_polocy_names()
self.policy_name = placement_policy
policy_cls = PlacementPolicyFactory.create(placement_policy) policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None

View File

@ -1,19 +1,22 @@
import torch
import itertools import itertools
import torch.distributed as dist
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set
from colossalai.logging import get_dist_logger
from collections import OrderedDict from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec from functools import partial
from colossalai.tensor import ProcessGroup as ColoProcessGroup from typing import Dict, Iterable, List, Optional, Set
from .reducer import Reducer
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from .reducer import Reducer
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
self.overflow_counter = 0 self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {} self.grads_device: Dict[torch.Tensor, torch.device] = {}
cpu_offload = self.gemini_manager.policy_name != 'cuda'
# TODO: get param order and filter unused params # TODO: get param order and filter unused params
for p in module.parameters(): for p in module.parameters():
assert isinstance(p, ColoParameter) assert isinstance(p, ColoParameter)
@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
fp32_data = p.data.float() fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half() p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size() dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) self.chunk_manager.append_tensor(tensor=p,
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.append_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p) self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups() self.chunk_manager.close_all_groups()
@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16) chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
self._logger = get_dist_logger() self._logger = get_dist_logger()
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):

View File

@ -1,14 +1,15 @@
from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy from copy import copy
import torch
from functools import lru_cache from functools import lru_cache
from typing import Callable, Optional, Set
from colossalai.tensor import ColoTensorSpec import torch
from colossalai.tensor import ProcessGroup, ReplicaSpec
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from typing import Optional, Set, Callable
from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS
@lru_cache(None) @lru_cache(None)
@ -57,25 +58,26 @@ class ColoTensor(torch.Tensor):
>>> pg = ProcessGroup() >>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()) >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
>>> # The tensor passed in is a tensor after sharding but not a global tensor. >>> # The tensor passed in is a tensor after sharding but not a global tensor.
>>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
>>> dims=[0], >>> dims=[0],
>>> num_partitions=[world_size]) >>> num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(pg, shard_spec) >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
Args: Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor. data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
""" """
torch_minor = int(torch.__version__.split('.')[1])
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
""" """
The signature of the __new__ has to be consistent with the torch.Tensor. The signature of the __new__ has to be consistent with the torch.Tensor.
Args: Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor. data (torch.Tensor): a torch tensor used as the payload the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization. spec (TensorSpec, optional): the tensor spec of initialization.
Returns: Returns:
ColoTensor: a ColoTensor wrappers the data. ColoTensor: a ColoTensor wrappers the data.
""" """
@ -112,7 +114,7 @@ class ColoTensor(torch.Tensor):
return self.process_group return self.process_group
def set_process_group(self, pg: ProcessGroup): def set_process_group(self, pg: ProcessGroup):
"""set_process_group """set_process_group
change the pg of the ColoTensor. Note that the valid use cases is limited. change the pg of the ColoTensor. Note that the valid use cases is limited.
Only existing pg is DP and dist spec is REPLICaTE is valid. Only existing pg is DP and dist spec is REPLICaTE is valid.
@ -135,7 +137,7 @@ class ColoTensor(torch.Tensor):
return self.process_group.tp_world_size() return self.process_group.tp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec): def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec """set_dist_spec
set dist spec and change the payloads. set dist spec and change the payloads.
Args: Args:
@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
if func in _COLOSSAL_OPS: if func in _COLOSSAL_OPS:
func = _COLOSSAL_OPS[func] func = _COLOSSAL_OPS[func]
if cls.torch_minor >= 12:
# in order to trigger pre-op hook in the forward of checkpoint module
# we have to capture the `backward` function
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1 # only has 1 paramter
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
with torch._C.DisableTorchFunction(): with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
if func in _get_my_nowrap_functions(): if func in _get_my_nowrap_functions():
@ -178,7 +190,7 @@ class ColoTensor(torch.Tensor):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}' return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
def _redistribute(self, dist_spec: _DistSpec) -> None: def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute """_redistribute
Note the function will not handle the logic of backward propagation! Note the function will not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function. It is used during model tensor initializations as an internal function.
@ -191,12 +203,12 @@ class ColoTensor(torch.Tensor):
self.dist_spec = dist_spec self.dist_spec = dist_spec
def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor': def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
"""redistribute """redistribute
Redistribute the tensor among processes. The rule is like this: Redistribute the tensor among processes. The rule is like this:
1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
DP process group not changed. DP process group not changed.
2. If the pg is not not None and not equal to the current process group. 2. If the pg is not not None and not equal to the current process group.
First, convert the tensor as replicated among the TP process group. First, convert the tensor as replicated among the TP process group.
Second, reset the process group to the new pg. Second, reset the process group to the new pg.
@ -220,7 +232,7 @@ class ColoTensor(torch.Tensor):
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec)) return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
def to_replicate_(self): def to_replicate_(self):
"""to_replicate_ """to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE an inline member function, converting dist spec of the tensor to REPLICATE
""" """

View File

@ -1,15 +1,17 @@
from enum import Enum
from typing import Dict, Set, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP from torch.optim import Optimizer
from typing import Dict, Tuple, Set
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.utils import disposable, get_current_device
class OptimState(Enum): class OptimState(Enum):
@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
def get_range_pair(local_chunk: Chunk, local_param: Parameter): def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param] param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin) begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end return begin, end

View File

@ -1,121 +1,124 @@
import torch from functools import partial
import colossalai
import pytest import pytest
import torch.multiprocessing as mp import torch
import torch.distributed as dist import torch.distributed as dist
from functools import partial import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port, get_current_device import colossalai
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.gemini import TensorState
from colossalai.tensor import ColoParameter from colossalai.gemini.chunk import Chunk
from colossalai.gemini import TensorState from colossalai.tensor import ColoParameter
from colossalai.gemini.chunk import Chunk from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
def dist_sum(x):
temp = torch.tensor([x], device=get_current_device())
dist.all_reduce(temp) def dist_sum(x):
return temp.item() temp = torch.tensor([x], device=get_current_device())
dist.all_reduce(temp)
return temp.item()
def add_param(param_list, param_cp_list, *args, **kwargs):
param = ColoParameter(torch.randn(*args, **kwargs))
param_list.append(param) def add_param(param_list, param_cp_list, *args, **kwargs):
param_cp_list.append(param.clone()) param = ColoParameter(torch.randn(*args, **kwargs))
param_list.append(param)
param_cp_list.append(param.clone())
def check_euqal(param, param_cp):
if param.device != param_cp.device:
temp = param.data.to(param_cp.device) def check_euqal(param, param_cp):
else: if param.device != param_cp.device:
temp = param.data temp = param.data.to(param_cp.device)
return torch.equal(temp, param_cp.data) else:
temp = param.data
return torch.equal(temp, param_cp.data)
@parameterize('init_device', [None, torch.device('cpu')])
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False]) @parameterize('init_device', [None, torch.device('cpu')])
def exam_chunk_basic(init_device, keep_gathered, pin_memory): @parameterize('keep_gathered', [True, False])
world_size = torch.distributed.get_world_size() @parameterize('pin_memory', [True, False])
pg = ColoProcessGroup() def exam_chunk_basic(init_device, keep_gathered, pin_memory):
my_chunk = Chunk(chunk_size=1024, world_size = torch.distributed.get_world_size()
process_group=pg, pg = ColoProcessGroup()
dtype=torch.float32, my_chunk = Chunk(chunk_size=1024,
init_device=init_device, process_group=pg,
keep_gathered=keep_gathered, dtype=torch.float32,
pin_memory=pin_memory) init_device=init_device,
cpu_shard_init=True,
param_list = [] keep_gathered=keep_gathered,
param_cp_list = [] pin_memory=pin_memory)
add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') param_list = []
add_param(param_list, param_cp_list, 4, 4) param_cp_list = []
add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
add_param(param_list, param_cp_list, 1, 1, 5) add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
add_param(param_list, param_cp_list, 4, 4)
for param in param_list: add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
my_chunk.append_tensor(param) add_param(param_list, param_cp_list, 1, 1, 5)
assert my_chunk.utilized_size == 597
for param, param_cp in zip(param_list, param_cp_list): for param in param_list:
check_euqal(param, param_cp) my_chunk.append_tensor(param)
my_chunk.close_chunk() assert my_chunk.utilized_size == 597
for param, param_cp in zip(param_list, param_cp_list):
if keep_gathered is False: check_euqal(param, param_cp)
assert my_chunk.cpu_shard.size(0) == 1024 // world_size my_chunk.close_chunk()
assert my_chunk.device_type == 'cpu'
assert my_chunk.can_move if keep_gathered is False:
my_chunk.shard_move(get_current_device()) assert my_chunk.cpu_shard.size(0) == 1024 // world_size
else: assert my_chunk.device_type == 'cpu'
assert my_chunk.chunk_total.size(0) == 1024 assert my_chunk.can_move
assert my_chunk.device_type == 'cuda' my_chunk.shard_move(get_current_device())
assert not my_chunk.can_move else:
assert my_chunk.chunk_total.size(0) == 1024
assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size assert my_chunk.device_type == 'cuda'
flag = my_chunk.has_inf_or_nan assert not my_chunk.can_move
assert not flag, "has_inf_or_nan is {}".format(flag)
assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
my_chunk.access_chunk() flag = my_chunk.has_inf_or_nan
assert my_chunk.device_type == 'cuda' assert not flag, "has_inf_or_nan is {}".format(flag)
for param, param_cp in zip(param_list, param_cp_list):
check_euqal(param, param_cp) my_chunk.access_chunk()
assert my_chunk.device_type == 'cuda'
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 for param, param_cp in zip(param_list, param_cp_list):
my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) check_euqal(param, param_cp)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1 assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
assert not my_chunk.can_release my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
for param in param_list: assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
my_chunk.tensor_trans_state(param, TensorState.COMPUTE) assert not my_chunk.can_release
my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
for param in param_list:
assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4 my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
assert my_chunk.can_reduce my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
my_chunk.reduce()
assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.can_reduce
if keep_gathered is False: my_chunk.reduce()
assert my_chunk.cuda_shard.size(0) == 1024 // world_size assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
assert my_chunk.device_type == 'cuda'
assert my_chunk.can_move if keep_gathered is False:
else: assert my_chunk.cuda_shard.size(0) == 1024 // world_size
assert my_chunk.chunk_total.size(0) == 1024 assert my_chunk.device_type == 'cuda'
assert my_chunk.device_type == 'cuda' assert my_chunk.can_move
assert not my_chunk.can_move else:
assert my_chunk.chunk_total.size(0) == 1024
assert my_chunk.device_type == 'cuda'
def run_dist(rank, world_size, port): assert not my_chunk.can_move
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_chunk_basic()
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@pytest.mark.dist exam_chunk_basic()
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_chunk_function(world_size): @pytest.mark.dist
run_func = partial(run_dist, world_size=world_size, port=free_port()) @pytest.mark.parametrize('world_size', [1, 2, 4])
mp.spawn(run_func, nprocs=world_size) @rerun_if_address_is_in_use()
def test_chunk_function(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
if __name__ == '__main__': mp.spawn(run_func, nprocs=world_size)
test_chunk_function(4)
if __name__ == '__main__':
test_chunk_function(4)

View File

@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy): @parameterize('keep_gather', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000 config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict) chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True) model = ZeroDDP(model, gemini_manager, pin_memory=True)
@ -101,4 +102,4 @@ def test_gpt(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_gpt(1) test_gpt(4)

View File

@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP from colossalai.nn.parallel import ZeroDDP
@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
check_param(model, torch_model) check_param(model, torch_model)
@parameterize('placement_policy', ['cuda', 'cpu'])
def exam_tiny_example(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd() exam_gpt_fwd_bwd()
exam_tiny_example()
@pytest.mark.dist @pytest.mark.dist
@ -113,4 +158,4 @@ def test_gpt(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_gpt(1) test_gpt(2)