mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero's incompatibility with checkpoint in torch-1.12 (#1786)
* [hotfix] fix zero's incompatibility with checkpoint in torch-1.12
* [zero] add cpu shard init
* [zero] add tiny example test
* [colo_tensor] fix bugs for torch-1.11
parent 32c1b843a9
commit c6a1a62636
@@ -1,230 +1,237 @@
-import torch
-from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
+from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
+
+import torch
 
-from colossalai.utils import get_current_device
-from colossalai.tensor import ColoTensor
-from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
+from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
+from colossalai.tensor import ColoTensor
+from colossalai.utils import get_current_device
 
 
 class ChunkManager:
     """
     A manager class to manipulate the tensors in chunks.
 
     Args:
         chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """
 
     def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
 
         self.device = init_device or get_current_device()
         self.size_config: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
             self.size_config[k] = v.pop('chunk_size')
             v['init_device'] = self.device
 
         self.chunk_groups: Dict[str, Deque] = dict()
         self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
         self.accessed_chunks: Set[Chunk] = set()
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
 
-    def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
+    def append_tensor(self,
+                      tensor: ColoTensor,
+                      group_type: str,
+                      config_key: int,
+                      cpu_offload: bool = False,
+                      pin_memory: bool = False) -> None:
         """Append a tensor to a chunk.
 
         Args:
             tensor: the tensor appended to the chunk
             group_type: the data type of the group
             config_key: the key of the group's name, usually the size of the dp world
+            cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
         assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
         assert config_key in self.size_config
 
         chunk_size = self.size_config[config_key]
         chunk_kwargs = self.kwargs_config[config_key]
         group_name = "{}_{}".format(group_type, config_key)
         chunk_group = self.__get_chunk_group(group_name)
 
         try:
             # append the tensor to the last chunk
             chunk_group[-1].append_tensor(tensor)
         except (IndexError, ChunkFullError):
             # the except statement will be triggered when there is no chunk or
             # the last chunk in the chunk group is full
             # this will create a new chunk and allocate this chunk to its corresponding process
             if chunk_group:
                 # the chunk group is not empty
                 # close the last chunk
                 self.__close_one_chunk(chunk_group[-1])
 
             if tensor.numel() > chunk_size:
                 chunk_size = tensor.numel()
             chunk = Chunk(
                 chunk_size=chunk_size,
                 process_group=tensor.process_group,
                 dtype=tensor.dtype,
+                cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
                 **chunk_kwargs,
             )
 
             chunk_group.append(chunk)
             chunk.append_tensor(tensor)
             self.__add_memory_usage(chunk.memory_usage)
 
         self.tensor_chunk_map[tensor] = chunk_group[-1]
 
     def close_all_groups(self):
         """Close all the chunks of all groups.
         """
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])
 
     def access_chunk(self, chunk: Chunk) -> None:
         """Make the chunk can be used for calculation.
         """
         if chunk in self.accessed_chunks:
             return
         self.__sub_memroy_usage(chunk.memory_usage)
         if chunk.device_type == 'cpu':
             chunk.shard_move(get_current_device())
         self.__add_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
 
     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA.
         """
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
             self.__sub_memroy_usage(chunk.memory_usage)
             self.__sub_accessed_chunk(chunk)
             self.__add_memory_usage(chunk.memory_usage)
 
     def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
         """Move the shard of the chunk to the target device.
         """
         if not chunk.can_move or chunk.device_type == device.type:
             return
         self.__sub_memroy_usage(chunk.memory_usage)
         chunk.shard_move(device, force_copy)
         self.__add_memory_usage(chunk.memory_usage)
 
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
         """Transit tensor state according to pre-defined state machine.
         """
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
     def reduce_chunk(self, chunk: Chunk) -> bool:
         """Reduce or all reduce the chunk.
         """
         if not chunk.can_reduce:
             return False
         self.__sub_memroy_usage(chunk.memory_usage)
         chunk.reduce()
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
 
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
         """
         Copy data to the chunk.
 
         Args:
             tensor (torch.Tensor): the tensor used to retrive meta information
             data (torch.Tensor): the tensor to be copied to the chunk
         """
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         """
         Return the chunk owning the tensor.
 
         Args:
             tensor (torch.Tensor): a torch tensor object
         """
         return self.tensor_chunk_map[tensor]
 
     def get_cuda_movable_chunks(self) -> List[Chunk]:
         """
         Get all chunks that can be moved.
         """
         chunk_list = []
         for chunk in self.accessed_chunks:
             if chunk.can_release:
                 chunk_list.append(chunk)
         chunk_list.sort(key=lambda x: x.count_id)
         return chunk_list
 
     def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
         """
         Get all chunks owning the input tensors.
 
         Args:
             tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
         """
         chunks = []
         for tensor in tensors:
             chunk = self.get_chunk(tensor)
             if chunk not in chunks:
                 chunks.append(chunk)
         return tuple(chunks)
 
     def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
         """Add extern static tensor to chunk manager.
         Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
         They are "static", which means their shape, dtype, device never change.
         Thus, their memory usage never changes.
 
         Args:
             tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
         """
         assert tensor not in self.tensor_chunk_map
         self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
 
     def __repr__(self) -> str:
         msg = [
             'Chunk Manager Information:\n',
             'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         ]
         for group_name, group in self.chunk_groups.items():
             msg.append(f'Group {group_name}:\n')
             for i, chunk in enumerate(group):
                 msg.append(f'[{i}] {chunk}\n')
         return ''.join(msg)
 
     def __get_chunk_group(self, group_name: str) -> Deque:
         """Register a chunk group.
         """
         if group_name not in self.chunk_groups:
             self.chunk_groups[group_name] = deque()
         return self.chunk_groups[group_name]
 
     def __close_one_chunk(self, chunk: Chunk):
-        device = get_current_device() if chunk.keep_gathered else self.device    # keep gathered chunk in cuda
         self.__sub_memroy_usage(chunk.memory_usage)
-        chunk.close_chunk(device)
+        chunk.close_chunk()
         self.__add_memory_usage(chunk.memory_usage)
 
     def __sub_memroy_usage(self, usage: Dict[str, int]):
         for k, v in usage.items():
             self.total_mem[k] -= v
 
     def __add_memory_usage(self, usage: Dict[str, int]):
         for k, v in usage.items():
             self.total_mem[k] += v
 
     def __add_accessed_chunk(self, chunk: Chunk):
         chunk.access_chunk()
         self.accessed_chunks.add(chunk)
         self.accessed_mem += chunk.chunk_mem
 
     def __sub_accessed_chunk(self, chunk: Chunk):
         chunk.release_chunk()
         self.accessed_chunks.remove(chunk)
         self.accessed_mem -= chunk.chunk_mem
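The hunk above adds a cpu_offload flag to ChunkManager.append_tensor, which is forwarded to each newly created Chunk as cpu_shard_init. Below is a minimal sketch (not taken from the commit) of how the manager is typically configured; it assumes a ColossalAI distributed context has already been launched so that a real ColoTensor carries a valid process group, and the sizes shown are illustrative only.

import torch
from colossalai.gemini.chunk import ChunkManager

# one sub-config per data-parallel world size; 'chunk_size' is popped into
# size_config and the remaining keys are forwarded to every Chunk() constructor
chunk_config = {2: {'chunk_size': 1024 * 1024, 'keep_gathered': False}}
manager = ChunkManager(chunk_config, init_device=torch.device('cpu'))

# with the new signature, closing the shard on CPU is requested per tensor
# (colo_param is a placeholder for a ColoTensor/ColoParameter built under a
# launched process group, so the call is left commented out here):
# manager.append_tensor(tensor=colo_param,
#                       group_type='fp16_param',
#                       config_key=2,            # must match a key of chunk_config
#                       cpu_offload=True,        # new in this commit
#                       pin_memory=True)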
@@ -1,9 +1,12 @@
import torch
import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from typing import List, Optional, Tuple

import torch

from colossalai.gemini.chunk import Chunk, ChunkManager

from .memory_tracer.memstats_collector import MemStatsCollectorV2
from .placement_policy import PlacementPolicyFactory


@@ -25,6 +28,7 @@ class GeminiManager:

    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
        self.policy_name = placement_policy
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        self._chunk_manager = chunk_manager
        self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
@@ -1,19 +1,22 @@
import torch
import itertools
import torch.distributed as dist
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
from functools import partial
from typing import Dict, Iterable, List, Optional, Set

from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
import torch
import torch.distributed as dist

from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2

from .reducer import Reducer

try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys

@@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP):
        self.overflow_counter = 0
        self.grads_device: Dict[torch.Tensor, torch.device] = {}

        cpu_offload = self.gemini_manager.policy_name != 'cuda'
        # TODO: get param order and filter unused params
        for p in module.parameters():
            assert isinstance(p, ColoParameter)

@@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP):
            fp32_data = p.data.float()
            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
            p.data = p.data.half()

            dp_world_size = p.process_group.dp_world_size()
            self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
            self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
            self.chunk_manager.append_tensor(tensor=p,
                                             group_type='fp16_param',
                                             config_key=dp_world_size,
                                             cpu_offload=cpu_offload,
                                             pin_memory=pin_memory)
            self.chunk_manager.append_tensor(tensor=fp32_p,
                                             group_type='fp32_param',
                                             config_key=dp_world_size,
                                             cpu_offload=cpu_offload,
                                             pin_memory=pin_memory)
            self.fp32_params.append(fp32_p)
            self.grads_device[p] = self.gemini_manager.default_device
        self.chunk_manager.close_all_groups()

@@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP):
            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
            chunk_32.init_pair(chunk_16)

            # keep gathered chunks are in CUDA
            if chunk_16.keep_gathered:
                self.grads_device[p] = get_current_device()

        self._logger = get_dist_logger()

    def forward(self, *args, **kwargs):
@@ -1,14 +1,15 @@
from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy
import torch
from functools import lru_cache
from typing import Callable, Optional, Set

from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import ProcessGroup, ReplicaSpec
import torch

from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional, Set, Callable
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec

from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS


@lru_cache(None)

@@ -57,25 +58,26 @@ class ColoTensor(torch.Tensor):
    >>> pg = ProcessGroup()
    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
    >>> # The tensor passed in is a tensor after sharding but not a global tensor.
    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
    >>>                        dims=[0],
    >>>                        num_partitions=[world_size])
    >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

    Args:
        data (torch.Tensor): a torch tensor used as the payload the colotensor.
        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
    """
    torch_minor = int(torch.__version__.split('.')[1])

    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
        """
        The signature of the __new__ has to be consistent with the torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload the colotensor.
            spec (TensorSpec, optional): the tensor spec of initialization.

        Returns:
            ColoTensor: a ColoTensor wrappers the data.
        """

@@ -112,7 +114,7 @@ class ColoTensor(torch.Tensor):
        return self.process_group

    def set_process_group(self, pg: ProcessGroup):
        """set_process_group
        change the pg of the ColoTensor. Note that the valid use cases is limited.
        Only existing pg is DP and dist spec is REPLICaTE is valid.

@@ -135,7 +137,7 @@ class ColoTensor(torch.Tensor):
        return self.process_group.tp_world_size()

    def set_dist_spec(self, dist_spec: _DistSpec):
        """set_dist_spec
        set dist spec and change the payloads.

        Args:

@@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor):
        if func in _COLOSSAL_OPS:
            func = _COLOSSAL_OPS[func]

        if cls.torch_minor >= 12:
            # in order to trigger pre-op hook in the forward of checkpoint module
            # we have to capture the `backward` function
            # and make sure that it does not in `torch._C.DisableTorchFunction()` context
            if func is torch.Tensor.backward:
                assert len(args) == 1    # only has 1 paramter
                backward_tensor = torch.Tensor(args[0])
                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                return backward_tensor.backward(**tensor_kwargs)

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        if func in _get_my_nowrap_functions():

@@ -178,7 +190,7 @@ class ColoTensor(torch.Tensor):
        return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'

    def _redistribute(self, dist_spec: _DistSpec) -> None:
        """_redistribute
        Note the function will not handle the logic of backward propagation!
        It is used during model tensor initializations as an internal function.

@@ -191,12 +203,12 @@ class ColoTensor(torch.Tensor):
        self.dist_spec = dist_spec

    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
        """redistribute
        Redistribute the tensor among processes. The rule is like this:

        1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
        DP process group not changed.

        2. If the pg is not not None and not equal to the current process group.
        First, convert the tensor as replicated among the TP process group.
        Second, reset the process group to the new pg.

@@ -220,7 +232,7 @@ class ColoTensor(torch.Tensor):
        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))

    def to_replicate_(self):
        """to_replicate_

        an inline member function, converting dist spec of the tensor to REPLICATE
        """
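The @@ -166,6 +168,16 @@ hunk above is the core of the hotfix: from torch 1.12 on, Tensor.backward is dispatched through __torch_function__, so running it inside torch._C.DisableTorchFunction() would hide the recomputation of checkpointed modules from Gemini's pre-op hooks. The sketch below (illustrative placeholder module, not part of the diff) shows the kind of block the fix targets: its forward goes through torch.utils.checkpoint, so the forward is re-executed while backward() runs, and the parameter hooks must still fire during that re-execution.

import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(torch.nn.Module):
    """Placeholder module whose forward is re-run during the backward pass."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        # the recomputation of self.linear during backward is what must still
        # trigger ColoTensor's pre-op hooks, hence the Tensor.backward capture
        return checkpoint(self.linear, x)

When such a block sits inside a ZeroDDP-wrapped model (as in the GPT tests further down), calling backward() on a ColoTensor without the special case would execute the whole backward pass under DisableTorchFunction and silently skip those hooks.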
@@ -1,15 +1,17 @@
from enum import Enum
from typing import Dict, Set, Tuple

import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP
from typing import Dict, Tuple, Set
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device


class OptimState(Enum):

@@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer):

        def get_range_pair(local_chunk: Chunk, local_param: Parameter):
            param_info = local_chunk.tensors_info[local_param]
            if local_chunk.keep_gathered:
                return param_info.offset, param_info.end
            begin = max(0, param_info.offset - local_chunk.shard_begin)
            end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
            return begin, end
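For reference, a worked example of the get_range_pair arithmetic above (the numbers are illustrative, not from the commit): a parameter occupying offsets [900, 1500) of a chunk whose local shard covers [1024, 2048) maps to [0, 476) within that shard.

# shard of this rank covers chunk offsets [1024, 2048)
shard_begin, shard_size = 1024, 1024
# the parameter's slice of the chunk
offset, end = 900, 1500

begin = max(0, offset - shard_begin)         # -> 0, clipped to the shard start
stop = min(shard_size, end - shard_begin)    # -> 476, clipped to the shard end

assert (begin, stop) == (0, 476)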
@@ -1,121 +1,124 @@
-import torch
-import colossalai
+from functools import partial
+
 import pytest
-import torch.multiprocessing as mp
+import torch
 import torch.distributed as dist
-from functools import partial
+import torch.multiprocessing as mp
+
+import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
-from colossalai.utils import free_port, get_current_device
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from colossalai.tensor import ColoParameter
 from colossalai.gemini import TensorState
 from colossalai.gemini.chunk import Chunk
+from colossalai.tensor import ColoParameter
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
 
 
 def dist_sum(x):
     temp = torch.tensor([x], device=get_current_device())
     dist.all_reduce(temp)
     return temp.item()
 
 
 def add_param(param_list, param_cp_list, *args, **kwargs):
     param = ColoParameter(torch.randn(*args, **kwargs))
     param_list.append(param)
     param_cp_list.append(param.clone())
 
 
 def check_euqal(param, param_cp):
     if param.device != param_cp.device:
         temp = param.data.to(param_cp.device)
     else:
         temp = param.data
     return torch.equal(temp, param_cp.data)
 
 
 @parameterize('init_device', [None, torch.device('cpu')])
 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
 def exam_chunk_basic(init_device, keep_gathered, pin_memory):
     world_size = torch.distributed.get_world_size()
     pg = ColoProcessGroup()
     my_chunk = Chunk(chunk_size=1024,
                      process_group=pg,
                      dtype=torch.float32,
                      init_device=init_device,
+                     cpu_shard_init=True,
                      keep_gathered=keep_gathered,
                      pin_memory=pin_memory)
 
     param_list = []
     param_cp_list = []
 
     add_param(param_list, param_cp_list, 8, 8, 8, device='cuda')
     add_param(param_list, param_cp_list, 4, 4)
     add_param(param_list, param_cp_list, 4, 8, 2, device='cuda')
     add_param(param_list, param_cp_list, 1, 1, 5)
 
     for param in param_list:
         my_chunk.append_tensor(param)
     assert my_chunk.utilized_size == 597
     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)
     my_chunk.close_chunk()
 
     if keep_gathered is False:
         assert my_chunk.cpu_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == 'cpu'
         assert my_chunk.can_move
         my_chunk.shard_move(get_current_device())
     else:
         assert my_chunk.chunk_total.size(0) == 1024
         assert my_chunk.device_type == 'cuda'
         assert not my_chunk.can_move
 
     assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size
     flag = my_chunk.has_inf_or_nan
     assert not flag, "has_inf_or_nan is {}".format(flag)
 
     my_chunk.access_chunk()
     assert my_chunk.device_type == 'cuda'
     for param, param_cp in zip(param_list, param_cp_list):
         check_euqal(param, param_cp)
 
     assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
     my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE)
     assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3
     assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1
     assert not my_chunk.can_release
 
     for param in param_list:
         my_chunk.tensor_trans_state(param, TensorState.COMPUTE)
         my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
 
     assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
     my_chunk.reduce()
     assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4
 
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == 'cuda'
         assert my_chunk.can_move
     else:
         assert my_chunk.chunk_total.size(0) == 1024
         assert my_chunk.device_type == 'cuda'
         assert not my_chunk.can_move
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_chunk_basic()
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2, 4])
 @rerun_if_address_is_in_use()
 def test_chunk_function(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
     test_chunk_function(4)
@@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
@parameterize('keep_gather', [False, True])
def exam_gpt_fwd_bwd(placement_policy, keep_gather):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy):
    world_size = torch.distributed.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

@@ -101,4 +102,4 @@ def test_gpt(world_size):


if __name__ == '__main__':
    test_gpt(1)
    test_gpt(4)
@@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP

@@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy):
    check_param(model, torch_model)


@parameterize('placement_policy', ['cuda', 'cpu'])
def exam_tiny_example(placement_policy):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 2:
            break

        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
        # debug_print([0], zero_logits, torch_logits)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()
    exam_tiny_example()


@pytest.mark.dist

@@ -113,4 +158,4 @@ def test_gpt(world_size):


if __name__ == '__main__':
    test_gpt(1)
    test_gpt(2)