mirror of https://github.com/hpcaitech/ColossalAI
[feature] A new ZeRO implementation (#1644)
parent b1be5b88bd
commit b28991dd0a
@ -1,10 +1,6 @@
from .chunk import TensorInfo, Chunk, TensorState
from .chunk_mgr import ChunkManager
from .chunk import TensorInfo, TensorState
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
from .gemini_mgr import GeminiManager

__all__ = [
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk',
'TensorState'
]
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState']
@ -1,316 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
COMPUTE = 1
|
||||
HOLD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
READY_FOR_REDUCE = 4
|
||||
|
||||
|
||||
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
|
||||
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
|
||||
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
|
||||
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
|
||||
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
|
||||
TensorState.HOLD))
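STATE_TRANS above is the whitelist of legal edges in the per-tensor state machine; tensor_trans_state later consults it and silently ignores any transition that is not listed. A minimal standalone sketch of that check, reusing only the TensorState and STATE_TRANS definitions above (try_transition is a hypothetical helper, not part of the API):

# Hypothetical helper mirroring the check inside tensor_trans_state:
# a requested transition is applied only if (current, target) is a legal edge.
def try_transition(current: TensorState, target: TensorState) -> TensorState:
    if (current, target) not in STATE_TRANS:
        return current    # invalid request: the state is left unchanged
    return target

# COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE is a legal chain...
state = TensorState.COMPUTE
state = try_transition(state, TensorState.HOLD_AFTER_BWD)      # accepted
state = try_transition(state, TensorState.READY_FOR_REDUCE)    # accepted
# ...while READY_FOR_REDUCE -> HOLD_AFTER_BWD is not listed and is ignored.
state = try_transition(state, TensorState.HOLD_AFTER_BWD)
assert state is TensorState.READY_FOR_REDUCE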
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
state: TensorState
|
||||
offset: int
|
||||
end: int
|
||||
|
||||
|
||||
class ChunkFullError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
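free_storage and alloc_storage above rely on the fact that resizing a tensor's underlying storage to zero releases the memory while keeping the tensor object and its shape metadata alive. A small self-contained illustration, assuming only PyTorch:

import torch

t = torch.empty(1024, dtype=torch.float32)
assert t.storage().size() == 1024

t.storage().resize_(0)            # what free_storage does: memory released, object kept
assert t.storage().size() == 0    # is_storage_empty(t) would now return True
assert t.numel() == 1024          # shape metadata is untouched

t.storage().resize_(t.numel())    # what alloc_storage does: memory re-acquired (uninitialized)
assert t.storage().size() == 1024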
|
||||
|
||||
|
||||
class Chunk:
|
||||
"""
|
||||
A chunk is a contiguous memory space which contains multiple tensors.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in a chunk
|
||||
src_rank (int): the process which owns the chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
||||
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
src_rank: int,
|
||||
process_group: ColoProcessGroup,
|
||||
dtype: torch.dtype,
|
||||
init_device: Optional[torch.device] = None,
|
||||
force_data_on_cuda: bool = False) -> None:
|
||||
self.size = chunk_size
|
||||
self.utilized_size = 0
|
||||
self.src_rank = src_rank
|
||||
self.process_group = process_group
|
||||
self.is_src_rank = process_group.dp_local_rank() == src_rank
|
||||
self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
|
||||
self.dtype = dtype
|
||||
device = init_device or get_current_device()
|
||||
if force_data_on_cuda:
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
|
||||
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
|
||||
if device.type == 'cuda':
|
||||
free_storage(self._cpu_data)
|
||||
else:
|
||||
free_storage(self.data)
|
||||
else:
|
||||
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||
self._cpu_data = None
|
||||
|
||||
# we only keep the chunk in full in the process by which the tensor is owned
|
||||
if not self.is_src_rank:
|
||||
free_storage(self._payload)
|
||||
|
||||
# each tensor is associated with a TensorInfo to track meta info
|
||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||
self.mem = self.size * self.data.element_size()
|
||||
|
||||
def append(self, tensor: torch.Tensor) -> None:
|
||||
"""
|
||||
Add a tensor to the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a tensor to be added to the chunk
|
||||
"""
|
||||
assert tensor.dtype == self.dtype
|
||||
new_utilized_size = self.utilized_size + tensor.numel()
|
||||
|
||||
# raise exception when the chunk size is exceeded
|
||||
if new_utilized_size > self.size:
|
||||
raise ChunkFullError
|
||||
|
||||
# set tensor state
|
||||
tensor_state = TensorState.FREE
|
||||
|
||||
# if the process owns the rank, then copy the tensor to its chunk buffer
|
||||
# otherwise set its storage size to 0 to reduce memory consumption
|
||||
if self.is_src_rank:
|
||||
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
|
||||
tensor_state = TensorState.HOLD
|
||||
assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
||||
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||
else:
|
||||
tensor.storage().resize_(0)
|
||||
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
||||
self.utilized_size = new_utilized_size
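append therefore packs each parameter into the next free slice of the chunk buffer and re-points tensor.data at that slice, so the parameter and the chunk share storage afterwards. A stripped-down sketch of that bookkeeping in plain PyTorch (the names here are illustrative, not part of the API):

import torch

chunk_buffer = torch.zeros(16)    # stands in for self._payload
utilized_size = 0

def append_into(buffer: torch.Tensor, offset: int, tensor: torch.Tensor) -> int:
    """Copy `tensor` into buffer[offset:end] and alias it to that slice."""
    end = offset + tensor.numel()
    assert end <= buffer.numel()  # the real code raises ChunkFullError here
    buffer[offset:end].copy_(tensor.flatten())
    tensor.data = buffer[offset:end].view(tensor.shape)
    return end

w = torch.randn(2, 3)
b = torch.randn(3)
utilized_size = append_into(chunk_buffer, utilized_size, w)
utilized_size = append_into(chunk_buffer, utilized_size, b)

# The tensors now alias the chunk: writing the buffer updates them in place.
chunk_buffer.zero_()
assert torch.all(w == 0) and torch.all(b == 0)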
|
||||
|
||||
def release(self) -> None:
|
||||
"""
|
||||
Release the memory space on processes which do not own the chunk.
|
||||
"""
|
||||
if not self.is_src_rank:
|
||||
free_storage(self._payload)
|
||||
self._update_tensors_state(TensorState.FREE)
|
||||
|
||||
def _update_tensors_ptr(self) -> None:
|
||||
assert type(self._payload) == torch.Tensor
|
||||
for tensor, tensor_info in self.tensors_info.items():
|
||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if prev_state is None or tensor_info.state == prev_state:
|
||||
tensor_info.state = next_state
|
||||
|
||||
def access(self) -> None:
|
||||
"""
|
||||
Broadcast the chunk to synchronize the tensors across data parallel processes.
|
||||
"""
|
||||
# recover the chunk on non-owner processes
|
||||
# and broadcast the chunk from the source to all processes
|
||||
if not self.is_src_rank:
|
||||
alloc_storage(self._payload)
|
||||
self.move_device(get_current_device(), update_ptr=False)
|
||||
dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
|
||||
|
||||
# update tensor meta info
|
||||
self._update_tensors_ptr()
|
||||
if not self.is_src_rank:
|
||||
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
||||
|
||||
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
|
||||
"""
|
||||
Move the chunk to a target device.
|
||||
|
||||
Args:
|
||||
device (torch.device): the target device for data movement.
|
||||
"""
|
||||
if self._payload.device == device:
|
||||
return
|
||||
if self._cpu_data is None:
|
||||
self.data.data = self.data.to(device)
|
||||
else:
|
||||
if device.type == 'cuda':
|
||||
# cpu -> cuda
|
||||
src = self._cpu_data
|
||||
dest = self.data
|
||||
else:
|
||||
# cuda -> cpu
|
||||
src = self.data
|
||||
dest = self._cpu_data
|
||||
alloc_storage(dest)
|
||||
dest.copy_(src)
|
||||
free_storage(src)
|
||||
|
||||
if update_ptr:
|
||||
self._update_tensors_ptr()
|
||||
|
||||
def reduce(self, is_all_reduce: bool = False) -> None:
|
||||
"""
|
||||
Reduce or all-reduce the chunk.
|
||||
|
||||
Args:
|
||||
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
||||
"""
|
||||
self.move_device(get_current_device(), update_ptr=False)
|
||||
if is_all_reduce:
|
||||
dist.all_reduce(self.data, group=self.process_group.dp_process_group())
|
||||
else:
|
||||
dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
|
||||
self._update_tensors_ptr()
|
||||
self._update_tensors_state(TensorState.HOLD)
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||
"""
|
||||
Make a transition of the tensor into the next state.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a torch Tensor object.
|
||||
tensor_state (TensorState): the target state for transition.
|
||||
"""
|
||||
|
||||
# As the gradient hook can be triggered either before or after post-backward
|
||||
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
|
||||
# or compute -> ready_for_reduce -> hold_after_bwd
|
||||
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
|
||||
# this function only apply valid state transformation
|
||||
# invalid calls will be ignored and nothing changes
|
||||
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
||||
# print(
|
||||
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
||||
# )
|
||||
return
|
||||
self.tensors_info[tensor].state = tensor_state
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
"""
|
||||
Copy data slice to the memory space indexed by the input tensor in the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrieve meta information
|
||||
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
||||
"""
|
||||
tensor_info = self.tensors_info[tensor]
|
||||
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
|
||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk can be released.
|
||||
"""
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if tensor_info.state != TensorState.HOLD:
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_move_device(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk can be moved across devices.
|
||||
"""
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_reduce(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk can be reduced.
|
||||
"""
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if tensor_info.state != TensorState.READY_FOR_REDUCE:
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check whether the chunk is empty.
|
||||
"""
|
||||
return is_storage_empty(self._payload)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
"""
|
||||
Check if the chunk has inf or nan values.
|
||||
"""
|
||||
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
|
||||
torch.isnan(self._payload[:self.utilized_size]).any().item()
|
||||
|
||||
def copy_(self, dest_chunk: 'Chunk'):
|
||||
"""
|
||||
Copy the data of this chunk to a destination chunk.
|
||||
"""
|
||||
assert not self.is_empty
|
||||
assert not dest_chunk.is_empty
|
||||
assert self.size == dest_chunk.size
|
||||
assert self.utilized_size == dest_chunk.utilized_size
|
||||
self._payload.copy_(dest_chunk._payload)
|
||||
self._update_tensors_ptr()
|
||||
|
||||
@property
|
||||
def device_type(self) -> str:
|
||||
"""
|
||||
Get the device type of the chunk.
|
||||
"""
|
||||
return self._payload.device.type
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(id(self))
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self is __o
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
@property
|
||||
def _payload(self) -> torch.Tensor:
|
||||
if self._cpu_data is None or is_storage_empty(self._cpu_data):
|
||||
return self.data
|
||||
return self._cpu_data
|
@ -0,0 +1,3 @@
from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration
@ -1,14 +1,55 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
|
||||
free_storage, alloc_storage
|
||||
|
||||
|
||||
class ChunkV2:
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
COMPUTE = 1
|
||||
HOLD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
READY_FOR_REDUCE = 4
|
||||
|
||||
|
||||
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
|
||||
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
|
||||
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
|
||||
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
|
||||
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
|
||||
TensorState.HOLD))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
state: TensorState
|
||||
offset: int
|
||||
end: int
|
||||
|
||||
|
||||
class ChunkFullError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
class Chunk:
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
|
@ -19,18 +60,18 @@ class ChunkV2:
|
|||
pin_memory: bool = False) -> None:
|
||||
"""
|
||||
Chunk: A container owning a piece of contiguous memory space for tensors
|
||||
AgChunk is a kind of chunk, which uses all-gather operation to gather the whole chunk.
|
||||
This kind of chunk is exclusively used for DDP and ZeRO DDP.
|
||||
Here we use an all-gather operation to gather the whole chunk.
Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
It is designed to make full use of communication and PCIe bandwidth.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the number of elements in a chunk
|
||||
chunk_size (int): the number of elements in the chunk
|
||||
process_group (ColoProcessGroup): the process group of this chunk
|
||||
dtype (torch.dtype): the data type of the chunk
|
||||
init_device (torch.device): optional, the device where the tensor is initialized
|
||||
The default value is None, which is the current GPU
|
||||
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
|
||||
pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory
|
||||
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
|
||||
"""
|
||||
|
||||
self.chunk_size = chunk_size
|
||||
|
@ -42,7 +83,8 @@ class ChunkV2:
|
|||
self.pg_rank = dist.get_rank(self.torch_pg)
|
||||
|
||||
# the chunk size should be divisible by the number of GPUs in the process group
|
||||
assert chunk_size % self.pg_size == 0
|
||||
if not keep_gathered:
|
||||
assert chunk_size % self.pg_size == 0
|
||||
self.shard_size = chunk_size // self.pg_size
|
||||
self.shard_begin = self.shard_size * self.pg_rank
|
||||
self.shard_end = self.shard_begin + self.shard_size
|
||||
|
@ -80,18 +122,15 @@ class ChunkV2:
|
|||
|
||||
# we introduce the paired chunk here
|
||||
# it refers to another chunk having the same parameters
|
||||
# but with different dtype(such as fp16_chunk.mapping_chunk -> fp32_chunk
|
||||
# but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
|
||||
self.paired_chunk = None
|
||||
# if the gradient of this chunk is reduced, the flag is True
|
||||
# so the flag is False for unused parameters
|
||||
self.grad_reduced_flag = False
|
||||
# if this chunk is synchronized with the optimizer, the flag is True
|
||||
self.optim_sync_flag = True
|
||||
# if the cpu_shard has been visited during the training step, the flag is True
|
||||
self.cpu_vis_flag = False
|
||||
|
||||
@property
|
||||
def memory_usage(self):
|
||||
def memory_usage(self) -> Dict[str, int]:
|
||||
cuda_memory = 0
|
||||
cpu_memory = 0
|
||||
|
||||
|
@ -112,7 +151,7 @@ class ChunkV2:
|
|||
return dict(cuda=cuda_memory, cpu=cpu_memory)
|
||||
|
||||
@property
|
||||
def device_type(self):
|
||||
def device_type(self) -> str:
|
||||
if self.chunk_temp is not None:
|
||||
return self.chunk_temp.device.type
|
||||
else:
|
||||
|
@ -123,6 +162,56 @@ class ChunkV2:
|
|||
else:
|
||||
return 'cpu'
|
||||
|
||||
@property
|
||||
def payload(self) -> torch.Tensor:
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if self.is_gathered:
|
||||
return self.chunk_total
|
||||
elif self.cuda_shard is not None:
|
||||
return self.cuda_shard
|
||||
else:
|
||||
return self.cpu_shard
|
||||
|
||||
@property
|
||||
def payload_mem(self) -> int:
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if self.is_gathered:
|
||||
return self.chunk_mem
|
||||
else:
|
||||
return self.shard_mem
|
||||
|
||||
@property
|
||||
def can_move(self) -> bool:
|
||||
return not self.is_gathered
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
if self.keep_gathered:
|
||||
return False
|
||||
else:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] + \
|
||||
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
"""Check if the chunk has inf or nan values in CUDA.
|
||||
"""
|
||||
if self.is_gathered:
|
||||
valid_tensor = self.chunk_total[:self.utilized_size]
|
||||
else:
|
||||
assert self.cuda_shard is not None # only check in CUDA
|
||||
valid_tensor = self.cuda_shard[:self.valid_end]
|
||||
|
||||
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor):
|
||||
"""Add a tensor to the chunk.
|
||||
|
||||
|
@ -150,7 +239,10 @@ class ChunkV2:
|
|||
self.utilized_size = new_utilized_size
|
||||
|
||||
def close_chunk(self, shard_dev: Optional[torch.device] = None):
|
||||
"""Close the chunk. Any tensor can't be appended to a closed chunk.
|
||||
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
|
||||
|
||||
Args:
|
||||
shard_dev: the device where the shard locates
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is not None
|
||||
|
@ -163,6 +255,7 @@ class ChunkV2:
|
|||
|
||||
if self.chunk_temp.device.type == 'cpu':
|
||||
self.chunk_total = self.chunk_temp.to(get_current_device())
|
||||
self.__update_tensors_ptr()
|
||||
else:
|
||||
self.chunk_total = self.chunk_temp
|
||||
self.chunk_temp = None
|
||||
|
@ -186,6 +279,12 @@ class ChunkV2:
|
|||
self.cuda_shard = None
|
||||
|
||||
def shard_move(self, device: torch.device, force_copy: bool = False):
|
||||
"""Move the shard tensor in the chunk.
|
||||
|
||||
Args:
|
||||
device: the device to which the shard will move
|
||||
force_copy: if True, the copy to the target device is always performed
|
||||
"""
|
||||
# sanity check
|
||||
assert not self.is_gathered
|
||||
# when the current chunk is not synchronized with the optimizer
|
||||
|
@ -223,8 +322,7 @@ class ChunkV2:
|
|||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
"""Make the chunk usable for the parameters inside it.
|
||||
It is an operation done in CUDA.
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
@ -234,8 +332,7 @@ class ChunkV2:
|
|||
self.__update_tensors_ptr()
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk.
|
||||
It is an operation done in CUDA.
|
||||
"""Release the usable chunk. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
@ -244,8 +341,7 @@ class ChunkV2:
|
|||
self.__scatter()
|
||||
|
||||
def reduce(self):
|
||||
"""Reduce scatter all the gradients.
|
||||
It is an operation done in CUDA.
|
||||
"""Reduce scatter all the gradients. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.is_gathered
|
||||
|
@ -267,7 +363,6 @@ class ChunkV2:
|
|||
free_storage(self.chunk_total)
|
||||
self.is_gathered = False
|
||||
self.__update_tensors_state(TensorState.HOLD)
|
||||
self.grad_reduced_flag = True
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||
"""
|
||||
|
@ -285,9 +380,6 @@ class ChunkV2:
|
|||
# this function only apply valid state transformation
|
||||
# invalid calls will be ignored and nothing changes
|
||||
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
||||
# print(
|
||||
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
||||
# )
|
||||
return
|
||||
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
|
||||
|
||||
|
@ -306,46 +398,58 @@ class ChunkV2:
|
|||
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||
|
||||
@property
|
||||
def can_move(self) -> bool:
|
||||
return not self.is_gathered
|
||||
|
||||
@property
|
||||
def can_release(self) -> bool:
|
||||
def get_valid_length(self) -> int:
|
||||
"""Get the valid length of the chunk's payload.
|
||||
"""
|
||||
if self.keep_gathered:
|
||||
return False
|
||||
return self.utilized_size
|
||||
else:
|
||||
return self.tensors_state_monitor[TensorState.HOLD] + \
|
||||
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
|
||||
return self.valid_end
|
||||
|
||||
@property
|
||||
def can_reduce(self):
|
||||
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
@property
|
||||
def has_inf_or_nan(self) -> bool:
|
||||
def init_pair(self, friend_chunk: 'Chunk') -> None:
|
||||
"""Initialize the paired chunk.
|
||||
"""
|
||||
Check if the chunk has inf or nan values in CUDA.
|
||||
"""
|
||||
if self.is_gathered:
|
||||
valid_tensor = self.chunk_total[:self.utilized_size]
|
||||
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
|
||||
self.paired_chunk = friend_chunk
|
||||
friend_chunk.paired_chunk = self
|
||||
else:
|
||||
assert self.cuda_shard is not None # only check in CUDA
|
||||
valid_tensor = self.cuda_shard[:self.valid_end]
|
||||
assert self.paired_chunk is friend_chunk
|
||||
assert friend_chunk.paired_chunk is self
|
||||
|
||||
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
|
||||
def optim_update(self) -> None:
|
||||
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.paired_chunk is not None
|
||||
|
||||
friend_chunk = self.paired_chunk
|
||||
if self.is_gathered is True:
|
||||
assert friend_chunk.is_gathered is True
|
||||
self.chunk_total.copy_(friend_chunk.chunk_total)
|
||||
self.optim_sync_flag = True
|
||||
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
|
||||
self.cuda_shard.copy_(friend_chunk.cuda_shard)
|
||||
self.optim_sync_flag = True
|
||||
self.cpu_vis_flag = False
|
||||
else:
|
||||
# optim_sync_flag is set to False
|
||||
# see shard_move function for more details
|
||||
assert friend_chunk.device_type == 'cpu'
|
||||
assert self.device_type == 'cpu'
|
||||
self.optim_sync_flag = False
|
||||
self.cpu_vis_flag = False
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
def __gather(self):
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is not None
|
||||
|
||||
if self.pg_size == 1:
|
||||
self.chunk_total = self.cuda_shard
|
||||
else:
|
||||
alloc_storage(self.chunk_total)
|
||||
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
alloc_storage(self.chunk_total)
|
||||
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
|
||||
self.cuda_shard = None
|
||||
self.is_gathered = True
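__gather rebuilds the full chunk by all-gathering every rank's shard into views produced by torch.chunk, so the result lands directly in chunk_total without an extra copy. A hedged single-process sketch of the same pattern using the gloo backend (the init boilerplate is illustrative; the real code uses the chunk's own process group and CUDA shards):

import os
import torch
import torch.distributed as dist

# Single-process stand-in for the data-parallel group (illustrative only).
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)

world_size = dist.get_world_size()
shard = torch.arange(8, dtype=torch.float32)              # this rank's shard
chunk_total = torch.empty(shard.numel() * world_size)

# Same pattern as __gather: gather every shard into a view of chunk_total.
gather_list = list(torch.chunk(chunk_total, chunks=world_size, dim=0))
dist.all_gather(gather_list, shard)

assert torch.equal(chunk_total[:shard.numel()], shard)
dist.destroy_process_group()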
|
||||
|
@ -404,9 +508,9 @@ class ChunkV2:
|
|||
def __eq__(self, __o: object) -> bool:
|
||||
return self is __o
|
||||
|
||||
def __repr__(self, detailed: bool = False):
|
||||
def __repr__(self, detailed: bool = True):
|
||||
output = [
|
||||
"AgChunk Information:\n",
|
||||
"Chunk Information:\n",
|
||||
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
|
||||
self.pg_size),
|
||||
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
|
||||
|
@ -442,6 +546,3 @@ class ChunkV2:
|
|||
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
|
||||
|
||||
return ''.join(output)
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
|
@ -4,23 +4,19 @@ from collections import deque
|
|||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ColoTensor
|
||||
from colossalai.gemini.chunk import ChunkFullError, TensorState
|
||||
from colossalai.gemini.update import ChunkV2 as Chunk
|
||||
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
|
||||
|
||||
|
||||
class ChunkManagerV2:
|
||||
class ChunkManager:
|
||||
"""
|
||||
A manager class to manipulate the tensors in chunks.
|
||||
|
||||
Args:
|
||||
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
|
||||
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
|
||||
pin_memory (bool): if True, all chunks have a piece of pinned CPU memory.
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_configuration: Dict[int, Dict],
|
||||
init_device: Optional[torch.device] = None,
|
||||
pin_memory: bool = False) -> None:
|
||||
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
|
||||
|
||||
self.device = init_device or get_current_device()
|
||||
self.size_config: Dict[int, int] = dict()
|
||||
|
@ -28,7 +24,6 @@ class ChunkManagerV2:
|
|||
for k, v in self.kwargs_config.items():
|
||||
self.size_config[k] = v.pop('chunk_size')
|
||||
v['init_device'] = self.device
|
||||
v['pin_memory'] = pin_memory
|
||||
|
||||
self.chunk_groups: Dict[str, Deque] = dict()
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
|
||||
|
@ -36,8 +31,14 @@ class ChunkManagerV2:
|
|||
self.lazy_release_tensors: List[torch.Tensor] = list()
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None:
|
||||
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
|
||||
"""Append a tensor to a chunk.
|
||||
|
||||
Args:
|
||||
tensor: the tensor appended to the chunk
|
||||
group_type: the data type of the group
|
||||
config_key: the key of the group's name, usually the size of the dp world
|
||||
pin_memory: whether the chunk is pinned in the cpu memory
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
|
||||
|
@ -66,7 +67,8 @@ class ChunkManagerV2:
|
|||
chunk_size=chunk_size,
|
||||
process_group=tensor.process_group,
|
||||
dtype=tensor.dtype,
|
||||
**chunk_kwargs
|
||||
pin_memory=pin_memory,
|
||||
**chunk_kwargs,
|
||||
)
|
||||
|
||||
chunk_group.append(chunk)
|
||||
|
@ -87,6 +89,8 @@ class ChunkManagerV2:
|
|||
if chunk in self.accessed_chunks:
|
||||
return
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
if chunk.device_type == 'cpu':
|
||||
chunk.shard_move(get_current_device())
|
||||
chunk.access_chunk()
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
self.accessed_chunks.add(chunk)
|
||||
|
@ -102,13 +106,13 @@ class ChunkManagerV2:
|
|||
self.__add_memory_usage(chunk.memory_usage)
|
||||
self.accessed_chunks.remove(chunk)
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
|
||||
"""Move the shard of the chunk to the target device.
|
||||
"""
|
||||
if not chunk.can_move or chunk.device_type == device.type:
|
||||
return
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
chunk.shard_move(device)
|
||||
chunk.shard_move(device, force_copy)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
|
@ -123,7 +127,7 @@ class ChunkManagerV2:
|
|||
if not chunk.can_reduce:
|
||||
return False
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
chunk.release_chunk()
|
||||
chunk.reduce()
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
return True
|
||||
|
||||
|
@ -165,14 +169,14 @@ class ChunkManagerV2:
|
|||
self.release_chunk(chunk)
|
||||
self.lazy_release_tensors.clear()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = ['Chunk Manager Information:\n',
|
||||
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n']
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg.append(f'Group {group_name}:\n')
|
||||
for i, chunk in enumerate(group):
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]:
|
||||
chunk_list = []
|
||||
for group_name in self.chunk_groups:
|
||||
if group_type in group_name:
|
||||
for chunk in self.chunk_groups[group_name]:
|
||||
if chunk.device_type == 'cuda' and chunk.can_move:
|
||||
chunk_list.append(chunk)
|
||||
return chunk_list
|
||||
|
||||
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
|
||||
"""
|
||||
|
@ -200,6 +204,17 @@ class ChunkManagerV2:
|
|||
assert tensor not in self.tensor_chunk_map
|
||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = [
|
||||
'Chunk Manager Information:\n',
|
||||
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
|
||||
]
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg.append(f'Group {group_name}:\n')
|
||||
for i, chunk in enumerate(group):
|
||||
msg.append(f'[{i}] {chunk}\n')
|
||||
return ''.join(msg)
|
||||
|
||||
def __get_chunk_group(self, group_name: str) -> Deque:
|
||||
"""Register a chunk group.
|
||||
"""
|
||||
|
@ -208,8 +223,9 @@ class ChunkManagerV2:
|
|||
return self.chunk_groups[group_name]
|
||||
|
||||
def __close_one_chunk(self, chunk: Chunk):
|
||||
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
|
||||
self.__sub_memroy_usage(chunk.memory_usage)
|
||||
chunk.close_chunk(self.device)
|
||||
chunk.close_chunk(device)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
def __sub_memroy_usage(self, usage: Dict[str, int]):
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
from typing import Dict, List
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
@ -7,7 +8,7 @@ from colossalai.tensor import ColoParameter
|
|||
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
||||
"""Filter those parameters whose size is too large from others.
|
||||
"""
|
||||
params_size = [p.numel() for p in model.parameters()]
|
||||
params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params_size_arr = np.array(params_size)
|
||||
|
||||
std = np.std(params_size_arr)
|
||||
|
@ -36,6 +37,9 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
|
|||
params_dict: Dict[int, List[ColoParameter]] = dict()
|
||||
for param in model.parameters():
|
||||
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
|
||||
if getattr(param, '_ddp_to_ignore', False):
|
||||
continue
|
||||
|
||||
param_key = param.process_group.dp_world_size()
|
||||
|
||||
if param_key not in params_dict:
|
||||
|
@ -47,13 +51,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
|
|||
|
||||
def search_chunk_configuration(
|
||||
model: nn.Module,
|
||||
search_range_mb: int,
|
||||
search_range_mb: float,
|
||||
search_interval_byte: int, # hidden size is the best value for the interval
|
||||
min_chunk_size_mb: int = 32,
|
||||
filter_exlarge_params: bool = True):
|
||||
search_range_byte = search_range_mb * 1024**2
|
||||
min_chunk_size_byte = min_chunk_size_mb * 1024**2
|
||||
assert search_range_byte % search_interval_byte == 0
|
||||
min_chunk_size_mb: float = 32,
|
||||
filter_exlarge_params: bool = True) -> Dict:
|
||||
search_range_byte = round(search_range_mb * 1024**2)
|
||||
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
|
||||
assert search_range_byte >= 0
|
||||
|
||||
params_dict = clasify_params(model)
|
||||
config_dict: Dict[int, Dict] = dict()
|
||||
|
@ -75,11 +79,12 @@ def search_chunk_configuration(
|
|||
max_size = min_chunk_size_byte
|
||||
for key in size_dict:
|
||||
max_size = max(max_size, max(size_dict[key]))
|
||||
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
|
||||
|
||||
min_chunk_waste = float('+inf')
|
||||
best_chunk_size = max_size
|
||||
best_chunk_size = start_size
|
||||
|
||||
for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte):
|
||||
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
|
||||
temp_waste = 0
|
||||
for key in size_dict:
|
||||
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
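The loop above scans candidate chunk sizes in steps of search_interval_byte and keeps the one with the least padding waste. A standalone sketch of that selection, with unused_bytes as a hypothetical stand-in for _get_unused_byte:

from typing import List

def unused_bytes(param_sizes: List[int], chunk_size: int) -> int:
    """Bytes left unused when packing params greedily into fixed-size chunks."""
    assert all(s <= chunk_size for s in param_sizes)
    wasted, acc = 0, 0
    for size in param_sizes:
        if acc + size > chunk_size:    # current chunk cannot fit this param
            wasted += chunk_size - acc
            acc = 0
        acc += size
    return wasted + (chunk_size - acc if acc > 0 else 0)

sizes = [300, 500, 200, 900]           # illustrative parameter sizes
candidates = range(1000, 2001, 500)    # start size, interval, search range
best = min(candidates, key=lambda c: unused_bytes(sizes, c))
print(best, unused_bytes(sizes, best))    # 1000 100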
|
|
@ -1,344 +0,0 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
||||
from collections import deque
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
"""
|
||||
A manager class to manipulate the tensors in chunks.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the size of a chunk.
|
||||
process_group (ColoProcessGroup): process group of the chunk.
|
||||
enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
|
||||
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
chunk_size: Optional[int],
|
||||
process_group: ColoProcessGroup,
|
||||
enable_distributed_storage: bool = False,
|
||||
init_device: Optional[torch.device] = None) -> None:
|
||||
assert chunk_size is None or chunk_size > 0
|
||||
assert isinstance(process_group, ColoProcessGroup)
|
||||
self.chunk_size = chunk_size
|
||||
self.process_group = process_group
|
||||
self.enable_distributed_storage = enable_distributed_storage
|
||||
self.device = init_device or get_current_device()
|
||||
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
|
||||
self.groups_force_data_on_cuda: Dict[str, bool] = {}
|
||||
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
|
||||
self.accessed_chunks: Set[Chunk] = set()
|
||||
self.lazy_release_tensors: List[torch.Tensor] = []
|
||||
if enable_distributed_storage and chunk_size is None:
|
||||
self.rank_load: Dict[str, torch.Tensor] = {}
|
||||
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
|
||||
|
||||
def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
|
||||
"""Create a chunk group.
|
||||
|
||||
Args:
|
||||
group_name (str): group name
|
||||
force_data_on_cuda (bool, optional): if True, the data of chunks in this group is always on CUDA. Defaults to False.
|
||||
"""
|
||||
assert group_name not in self.chunk_groups
|
||||
self.chunk_groups[group_name] = deque()
|
||||
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
|
||||
"""
|
||||
Append a tensor to a chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a tensor to append to the chunk.
|
||||
group_name (str): the name of the chunk group.
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
if isinstance(tensor, ColoTensor):
|
||||
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
|
||||
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
|
||||
try:
|
||||
# append the tensor to the last chunk
|
||||
self.chunk_groups[group_name][-1].append(tensor)
|
||||
except (IndexError, ChunkFullError):
|
||||
# the except statement will be triggered when there is no chunk or
|
||||
# the last chunk in the chunk group is full
|
||||
# this will create a new chunk and allocate this chunk to its corresponding process
|
||||
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
|
||||
chunk_size = tensor.numel()
|
||||
else:
|
||||
chunk_size = self.chunk_size or tensor.numel()
|
||||
src_rank = self._get_next_src_rank(group_name)
|
||||
chunk = Chunk(chunk_size,
|
||||
src_rank,
|
||||
self.process_group,
|
||||
tensor.dtype,
|
||||
self.device,
|
||||
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
|
||||
|
||||
if self.enable_distributed_storage and self.chunk_size is None:
|
||||
self.rank_load[group_name][src_rank] += chunk_size
|
||||
|
||||
self.chunk_groups[group_name].append(chunk)
|
||||
chunk.append(tensor)
|
||||
if not chunk.is_empty:
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
|
||||
if not self.enable_distributed_storage:
|
||||
# as distributed storage is not enabled, there is no need to broadcast
|
||||
# chunks, thus we set these chunks as accessed
|
||||
self.accessed_chunks.add(self.chunk_groups[group_name][-1])
|
||||
|
||||
def _get_next_src_rank(self, group_name: str) -> int:
|
||||
if not self.enable_distributed_storage:
|
||||
# the chunk is owned by the current rank if no distributed storage is enabled
|
||||
return self.process_group.dp_local_rank()
|
||||
if self.chunk_size is None:
|
||||
if group_name not in self.rank_load:
|
||||
self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
|
||||
|
||||
# the process owning the tensor will be the process with the smallest number of elements
|
||||
src_rank = torch.argmin(self.rank_load[group_name]).item()
|
||||
else:
|
||||
# chunk is owned by processes in a round-robin fashion
|
||||
chunk_idx = len(self.chunk_groups[group_name])
|
||||
src_rank = chunk_idx % self.process_group.dp_world_size()
|
||||
return src_rank
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
"""
|
||||
Synchronize the chunks via broadcast.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk to synchronize.
|
||||
"""
|
||||
if chunk in self.accessed_chunks:
|
||||
if chunk.device_type != 'cuda':
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.move_device(get_current_device())
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
return
|
||||
if not chunk.is_empty:
|
||||
# as tensor is moved to the target device
|
||||
# the memory consumption of the original device is reduced
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.access()
|
||||
self.accessed_chunks.add(chunk)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""
|
||||
Release the memory space of a chunk.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk to release memory space
|
||||
"""
|
||||
|
||||
if not self.enable_distributed_storage:
|
||||
return
|
||||
if chunk not in self.accessed_chunks:
|
||||
return
|
||||
if chunk.can_release:
|
||||
chunk.release()
|
||||
self.accessed_chunks.remove(chunk)
|
||||
if chunk.is_empty:
|
||||
# update the memory consumption after releasing
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
|
||||
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
|
||||
"""
|
||||
Move the chunk to the target device.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk to move to target device
|
||||
device (torch.device): target device
|
||||
"""
|
||||
if chunk.device_type == device.type:
|
||||
return
|
||||
if chunk.can_move_device and not chunk.is_empty:
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.move_device(device, update_ptr=update_ptr)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
|
||||
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
|
||||
"""
|
||||
Transit tensor state according to pre-defined state machine.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor for state transition
state (TensorState): next tensor state for transition
|
||||
"""
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
chunk.tensor_trans_state(tensor, state)
|
||||
|
||||
def reduce_chunk(self, chunk: Chunk) -> bool:
|
||||
"""
|
||||
Reduce or all-reduce the chunk. If enable_distributed_storage is False, all-reduce is used.
Otherwise, this method uses reduce.
|
||||
|
||||
Args:
|
||||
chunk (Chunk): the chunk for reduction.
|
||||
"""
|
||||
if not chunk.can_reduce:
|
||||
return False
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
return True
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
|
||||
"""
|
||||
Copy data to the chunk.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): the tensor used to retrieve meta information
|
||||
data (torch.Tensor): the tensor to be copied to the chunk
|
||||
"""
|
||||
chunk = self.tensor_chunk_map[tensor]
|
||||
chunk.copy_tensor_to_chunk_slice(tensor, data)
|
||||
|
||||
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
|
||||
"""
|
||||
Return the chunk owning the tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): a torch tensor object
|
||||
"""
|
||||
return self.tensor_chunk_map[tensor]
|
||||
|
||||
def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
|
||||
"""
|
||||
Add tensors to the buffer for lazy release.
|
||||
|
||||
Args:
|
||||
tensors (List[torch.Tensor]): the tensors to be released lazily
|
||||
"""
|
||||
self.lazy_release_tensors.extend(tensors)
|
||||
|
||||
def exec_lazy_release(self) -> None:
|
||||
"""
|
||||
Execute release for tensors added to the lazy release buffer.
|
||||
"""
|
||||
|
||||
for chunk in self.get_chunks(self.lazy_release_tensors):
|
||||
self.release_chunk(chunk)
|
||||
self.lazy_release_tensors.clear()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = f'Rank {self.process_group.dp_local_rank()}:\n'
|
||||
msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
|
||||
for group_name, group in self.chunk_groups.items():
|
||||
msg += f'Group {group_name}:\n'
|
||||
for i, chunk in enumerate(group):
|
||||
msg += f'[{i}] {chunk}\n'
|
||||
return msg
|
||||
|
||||
@staticmethod
|
||||
def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
|
||||
"""
|
||||
Calculate the utilization rate of a chunk.
|
||||
|
||||
Args:
|
||||
chunk_size (int): the size of a chunk
|
||||
params_numel (List[int]): the list of integers representing the number of elements of parameters
|
||||
"""
|
||||
assert len(params_numel) > 0
|
||||
total_size = 0
|
||||
total_utilized_size = 0
|
||||
cur_chunk_utilized_size = 0
|
||||
for size in params_numel:
|
||||
assert chunk_size >= size
|
||||
total_utilized_size += size
|
||||
if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
|
||||
total_size += chunk_size
|
||||
cur_chunk_utilized_size = 0
|
||||
cur_chunk_utilized_size += size
|
||||
return total_utilized_size / total_size
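get_chunk_util estimates the fraction of allocated chunk memory that is actually occupied by parameters for a candidate chunk size. A quick hand-check of the formula, as a hypothetical standalone copy of the same greedy packing:

from typing import List

def chunk_utilization(chunk_size: int, params_numel: List[int]) -> float:
    """Standalone copy of the greedy packing used by get_chunk_util."""
    total_size = 0        # elements allocated across all opened chunks
    total_utilized = 0    # elements actually occupied by parameters
    cur_used = 0          # occupancy of the chunk currently being filled
    for size in params_numel:
        assert chunk_size >= size
        total_utilized += size
        if total_size == 0 or cur_used + size > chunk_size:
            total_size += chunk_size    # open a new chunk
            cur_used = 0
        cur_used += size
    return total_utilized / total_size

# 600 and 500 do not fit into one 1024-element chunk, so two chunks are opened:
# utilization is 1100 / 2048, roughly 0.54.
print(chunk_utilization(1024, [600, 500]))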
|
||||
|
||||
@staticmethod
|
||||
def search_chunk_size(module: torch.nn.Module,
|
||||
search_range: int,
|
||||
n_grids: int,
|
||||
min_chunk_size: Optional[int] = None,
|
||||
filter_exlarge_params: bool = True) -> int:
|
||||
"""
|
||||
Search for the chunk size for optimal chunk utilization.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): a torch module object
|
||||
search_range (int): the range of chunk size to search. The actual search range will be from
|
||||
max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range.
|
||||
n_grids (int): the number of intervals in the search range
|
||||
min_chunk_size (int): optional, the minimum size for a chunk. The default is None.
|
||||
|
||||
"""
|
||||
assert search_range % n_grids == 0
|
||||
# TODO(ver217): sort params and filter unused ones
|
||||
params_numel = [p.numel() for p in module.parameters()]
|
||||
if filter_exlarge_params:
|
||||
params_numel = _filter_exlarge_params(params_numel)
|
||||
max_param_numel = max(params_numel)
|
||||
if min_chunk_size is not None:
|
||||
assert min_chunk_size >= max_param_numel
|
||||
else:
|
||||
min_chunk_size = max_param_numel
|
||||
step_size = search_range // n_grids
|
||||
max_chunk_util = -1
|
||||
best_chunk_size = -1
|
||||
for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
|
||||
chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
|
||||
if chunk_util > max_chunk_util:
|
||||
max_chunk_util = chunk_util
|
||||
best_chunk_size = chunk_size
|
||||
return best_chunk_size
|
||||
|
||||
def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
|
||||
"""
|
||||
Copy chunk data from one group to another group.
|
||||
|
||||
Args:
|
||||
dest_group_name (str): the destination group which receives the copied data
|
||||
src_group_name (str): the source group which provides the data to copy
|
||||
"""
|
||||
for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
|
||||
if not dest_chunk.is_empty:
|
||||
dest_chunk.copy_(src_chunk)
|
||||
|
||||
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
|
||||
"""
|
||||
Get all chunks owning the input tensors.
|
||||
|
||||
Args:
|
||||
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
|
||||
"""
|
||||
chunks = []
|
||||
for tensor in tensors:
|
||||
chunk = self.get_chunk(tensor)
|
||||
if chunk not in chunks:
|
||||
chunks.append(chunk)
|
||||
return tuple(chunks)
|
||||
|
||||
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
|
||||
"""Add extern static tensor to chunk manager.
|
||||
Those tensors won't be managed by the chunk manager, but we want to monitor their memory usage.
|
||||
They are "static", which means their shape, dtype, device never change.
|
||||
Thus, their memory usage never changes.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
|
||||
"""
|
||||
assert tensor not in self.tensor_chunk_map
|
||||
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
|
||||
|
||||
|
||||
def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
|
||||
params_numel_arr = np.array(params_numel)
|
||||
std = np.std(params_numel_arr)
|
||||
mean = np.mean(params_numel_arr)
|
||||
upper_limit = mean + 3 * std
|
||||
return list(filter(lambda x: x <= upper_limit, params_numel))
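_filter_exlarge_params drops parameters whose element count lies more than three standard deviations above the mean, so a single huge embedding does not dictate the searched chunk size. A quick numpy illustration of that cut-off:

import numpy as np

params_numel = [1_000] * 20 + [50_000_000]    # twenty small tensors plus one outlier
arr = np.array(params_numel)
upper_limit = np.mean(arr) + 3 * np.std(arr)

kept = [n for n in params_numel if n <= upper_limit]
assert 50_000_000 not in kept                 # the outlier is filtered out
print(len(kept), upper_limit)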
@ -3,7 +3,7 @@ import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from colossalai.gemini import Chunk, ChunkManager
from colossalai.gemini.chunk import Chunk, ChunkManager
from .placement_policy import PlacementPolicyFactory

@ -56,37 +56,44 @@ class GeminiManager:
|
|||
self._evict_time = 0
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None:
|
||||
""" Adjust the layout of statefuil tensor according to the information provided
|
||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||
"""
|
||||
# find stateful tensor in state COMPUTE
|
||||
start = time()
|
||||
self._record_chunks_order(chunks)
|
||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
|
||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type)
|
||||
self._layout_time += time() - start
|
||||
vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list,
|
||||
|
||||
vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
|
||||
cuda_demand=cuda_demand,
|
||||
warmup=self._warmup,
|
||||
compute_list=self._compute_list,
|
||||
compute_idx=self._compute_idx)
|
||||
|
||||
self._d2h_volume += vol
|
||||
self._evict_time += evict_time
|
||||
# move COMPUTE tensors to CUDA
|
||||
self._h2d_volume += cuda_demand
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str):
|
||||
start = time()
|
||||
cuda_demand = 0
|
||||
for chunk in chunks:
|
||||
if chunk.device_type == 'cpu' or chunk.is_empty:
|
||||
cuda_demand += chunk.mem
|
||||
if chunk.device_type == 'cuda':
|
||||
if chunk.is_gathered:
|
||||
pass
|
||||
else:
|
||||
cuda_demand += chunk.chunk_mem - chunk.shard_mem
|
||||
elif chunk.device_type == 'cpu':
|
||||
cuda_demand += chunk.chunk_mem
|
||||
else:
|
||||
raise RuntimeError
|
||||
self._comp_cuda_demand_time += time() - start
|
||||
can_evict_chunks = []
|
||||
for chunk in self._chunk_manager.chunk_groups[group_name]:
|
||||
if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device:
|
||||
can_evict_chunks.append(chunk)
|
||||
|
||||
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type)
|
||||
return cuda_demand, can_evict_chunks
|
||||
|
||||
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini import ChunkManager
from colossalai.gemini.chunk import ChunkManager

import torch
import time

@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
|
|||
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||
from typing import Type
|
||||
import functools
|
||||
from colossalai.gemini import Chunk, ChunkManager
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
|
||||
|
||||
class PlacementPolicy(ABC):
|
||||
|
@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
|
|||
self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
|
||||
|
||||
@abstractmethod
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> None:
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
|
@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy):
|
|||
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
volume = 0
|
||||
start = time()
|
||||
for chunk in can_evict_chunks:
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
|
||||
volume += chunk.mem
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
volume += chunk.shard_mem
|
||||
return volume, time() - start
|
||||
|
||||
|
||||
|
@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy):
|
|||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
return 0, 0
|
||||
|
||||
@staticmethod
|
||||
|
@ -59,7 +59,8 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
|
||||
need_mem_stats: bool = True
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
_warmup_non_model_data_ratio: float = 0.8
|
||||
_steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
|
@ -70,14 +71,14 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
can_evict_chunks: List[Chunk],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: List[Tuple[Chunk, ...]] = [],
|
||||
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> int:
|
||||
**kwargs) -> Tuple[int, float]:
|
||||
"""
|
||||
Evict tensors from CUDA device.
|
||||
|
||||
Args:
|
||||
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
|
||||
can_evict_chunks (List[Chunk]): the list of chunks that can be evicted.
|
||||
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
|
||||
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
|
||||
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
|
||||
|
@ -114,12 +115,12 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
for chunk in to_free_chunks:
|
||||
if freed_cuda_model_data >= to_free_cuda_model_data:
|
||||
break
|
||||
freed_cuda_model_data += chunk.mem
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
|
||||
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
freed_cuda_model_data += chunk.shard_mem
|
||||
if freed_cuda_model_data < to_free_cuda_model_data:
|
||||
raise RuntimeError(
|
||||
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
||||
)
|
||||
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
|
||||
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
|
||||
return freed_cuda_model_data, time() - start
|
||||
|
||||
@staticmethod
|
||||
|
@@ -147,7 +148,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, PlacementPolicy] = {
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
'cpu': CPUPlacementPolicy,
|
||||
'cuda': CUDAPlacementPolicy,
|
||||
'auto': AutoPlacementPolicy
|
||||
|
|
|
@@ -1,131 +0,0 @@
|
|||
import queue
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, Dict
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
|
||||
|
||||
def evict_check(st: StatefulTensor) -> bool:
|
||||
if st.state is not TensorState.COMPUTE and st.device.type == 'cuda':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Here ST means Stateful Tensor
|
||||
class BaseSTContainer(ABC):
|
||||
"""A type of container that store all potential stateful tensors which can be evicted from
|
||||
CUDA. This kind of stateful tensor should satisfy two conditions. One is that it hasn't been
|
||||
evicted, meaning its device type is CUDA; the other is that it isn't pinned in CUDA
|
||||
memory, meaning its state isn't COMPUTE.
|
||||
|
||||
This container should receive a stateful tensor when it becomes HOLD-like after COMPUTE.
|
||||
It pops stateful tensors in the `evict_tensors` function.
|
||||
|
||||
In order to acquire an optimal eviction policy, users may need to provide the computation step
|
||||
indices of each stateful tensor, so we can use a heap to maintain all potentially evictable
|
||||
stateful tensors. When popping, we get the stateful tensor whose next use is furthest from the
|
||||
current computation step.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
|
||||
self.compute_step_dict = compute_step_dict
|
||||
self.total_step = total_step
|
||||
|
||||
@abstractmethod
|
||||
def empty(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pop(self) -> Optional[StatefulTensor]:
|
||||
pass
|
||||
|
||||
|
||||
class QueueSTContainer(BaseSTContainer):
|
||||
"""Queue type stateful tensor container. This is used in 'cpu' tensor placement policy.
|
||||
It pops potentially evictable stateful tensors in FIFO order.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
|
||||
super().__init__(compute_step_dict, total_step)
|
||||
self.container = None
|
||||
|
||||
def empty(self) -> bool:
|
||||
assert self.container is not None
|
||||
return self.container.empty()
|
||||
|
||||
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
|
||||
self.container = queue.SimpleQueue()
|
||||
for stateful_tensor in stateful_tensor_list:
|
||||
self.container.put(stateful_tensor)
|
||||
|
||||
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
|
||||
self.container.put(stateful_tensor)
|
||||
|
||||
def pop(self) -> Optional[StatefulTensor]:
|
||||
ret = None
|
||||
while not self.empty():
|
||||
out_tensor = self.container.get()
|
||||
if evict_check(out_tensor):
|
||||
ret = out_tensor
|
||||
break
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class HeapSTContainer(BaseSTContainer):
|
||||
"""Heap type stateful tensor container. This is used in 'auto' tensor placement policy.
|
||||
It pops potentially evictable stateful tensors in order of the distance between the current
|
||||
step and the step at which they are next used.
|
||||
"""
|
||||
|
||||
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
|
||||
super().__init__(compute_step_dict, total_step)
|
||||
self.container = None
|
||||
|
||||
def empty(self) -> bool:
|
||||
assert self.container is not None
|
||||
return self.container == []
|
||||
|
||||
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
|
||||
self.container = []
|
||||
for stateful_tensor in stateful_tensor_list:
|
||||
# we want to pop the tensor which has the greatest next_step
|
||||
# so the weight is next_step multiplied by -1
|
||||
weight = -self.__get_next_compute_step(stateful_tensor, -1)
|
||||
self.container.append((weight, stateful_tensor))
|
||||
heapq.heapify(self.container)
|
||||
|
||||
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
|
||||
# we want to pop the tensor which has the greatest next_step
|
||||
# so the weight is next_step multiplied by -1
|
||||
weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
|
||||
heapq.heappush(self.container, (weight, stateful_tensor))
|
||||
|
||||
def pop(self) -> Optional[StatefulTensor]:
|
||||
ret = None
|
||||
while not self.empty():
|
||||
_, out_tensor = heapq.heappop(self.container)
|
||||
if evict_check(out_tensor):
|
||||
ret = out_tensor
|
||||
break
|
||||
return ret
|
||||
|
||||
def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int):
|
||||
# compute the id of next step
|
||||
# if the tensor is not used in the future
|
||||
# next_step is set to the maximum
|
||||
next_step = self.total_step
|
||||
step_list = self.compute_step_dict[stateful_tensor]
|
||||
for step in step_list:
|
||||
if step > cur_step:
|
||||
next_step = step
|
||||
break
|
||||
return next_step
|
|
@@ -1,3 +0,0 @@
|
|||
from .chunkv2 import ChunkV2
|
||||
from .chunk_mgrv2 import ChunkManagerV2
|
||||
from .search_utils import clasify_params, search_chunk_configuration
|
|
@@ -3,16 +3,18 @@ import itertools
|
|||
import torch.distributed as dist
|
||||
from functools import partial
|
||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||
from colossalai.gemini.chunk import TensorState, Chunk
|
||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
from colossalai.logging import get_dist_logger
|
||||
from collections import OrderedDict
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from .reducer import Reducer
|
||||
|
||||
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
|
||||
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
|
@@ -208,29 +210,41 @@ class ZeroDDP(ColoDDP):
|
|||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False) -> None:
|
||||
super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
|
||||
super().__init__(module, process_group=ColoProcessGroup())
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = gemini_manager.chunk_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
self.fp32_params: List[ColoParameter] = []
|
||||
self.fp32_params: List[ColoTensor] = []
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
|
||||
self.chunk_manager.create_group('fp32_param')
|
||||
|
||||
# TODO: get param order and filter unused params
|
||||
for p in module.parameters():
|
||||
assert isinstance(p, ColoParameter)
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
p.data = p.half()
|
||||
continue
|
||||
fp32_p = p.float().detach()
|
||||
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
fp32_data = p.float().data
|
||||
p.data = p.half()
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param')
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
self.chunk_manager.close_all_groups()
|
||||
self._cast_buffers()
|
||||
|
||||
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
for p, fp32_p in zip(params_list, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
chunk_32.init_pair(chunk_16)
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
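For context, the constructor above is exercised roughly as follows in the tests later in this commit. This is a sketch, not part of the commit: it assumes a process already launched with colossalai.launch on a CUDA device, and build_model is a tiny stand-in for any module factory whose parameters become ColoParameters inside ColoInitContext.

import torch
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP

def build_model():
    # stand-in model; parameters created under ColoInitContext become ColoParameters
    return torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))

with ColoInitContext(device=get_current_device()):
    model = build_model()

# chunk configuration search and managers, mirroring the tests below
config = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config)
gemini_manager = GeminiManager('auto', chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)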
@@ -248,10 +262,7 @@ class ZeroDDP(ColoDDP):
|
|||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
continue
|
||||
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
|
||||
p.grad = None
|
||||
else:
|
||||
p.grad = p.data
|
||||
p.grad = None
|
||||
|
||||
def _post_backward(self):
|
||||
self.chunk_manager.exec_lazy_release()
|
||||
|
@@ -276,21 +287,22 @@ class ZeroDDP(ColoDDP):
|
|||
free_storage(empty_grad)
|
||||
with torch._C.DisableTorchFunction():
|
||||
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
|
||||
if self.dp_world_size > 1:
|
||||
grad = grad / self.dp_world_size
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
|
||||
chunk = self.chunk_manager.get_chunk(p)
|
||||
chunk.copy_tensor_to_chunk_slice(p, grad)
|
||||
reduced = self.chunk_manager.reduce_chunk(chunk)
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if reduced and not chunk.is_empty:
|
||||
if reduced:
|
||||
if chunk.is_gathered:
|
||||
chunk.chunk_total.div_(chunk.pg_size)
|
||||
else:
|
||||
chunk.cuda_shard.div_(chunk.pg_size)
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
|
||||
for tensor in chunk.get_tensors():
|
||||
self.grads_device[tensor] = device
|
||||
|
||||
|
@@ -311,14 +323,11 @@ class ZeroDDP(ColoDDP):
|
|||
['bias', 'weight']
|
||||
|
||||
"""
|
||||
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||
record_flag = (not only_rank_0) or is_rank_0
|
||||
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
|
||||
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
|
||||
|
||||
for hook in self._state_dict_hooks.values():
|
||||
hook_result = hook(self, destination, prefix, local_metadata)
|
||||
|
@@ -326,7 +335,7 @@ class ZeroDDP(ColoDDP):
|
|||
destination = hook_result
|
||||
return destination
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
@@ -339,30 +348,30 @@ class ZeroDDP(ColoDDP):
|
|||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
|
||||
|
||||
# save parameters
|
||||
param_to_save_data = dict()
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
# record the original device of the chunk
|
||||
org_chunk_dev_typ = chunk.device_type
|
||||
self.chunk_manager.access_chunk(chunk)
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor in chunk.get_tensors():
|
||||
rec_p = torch.empty([0])
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
record_tensor = torch.empty([0])
|
||||
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
|
||||
if record_flag:
|
||||
rec_p = tensor.cpu() # move the whole tensor to CPU mem
|
||||
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
|
||||
|
||||
assert tensor not in param_to_save_data
|
||||
param_to_save_data[tensor] = rec_p
|
||||
# release the actual memory of the chunk
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
param_to_save_data[tensor] = record_tensor
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
|
||||
rec_p = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
|
||||
record_parameter = param_to_save_data[fp32_p]
|
||||
destination[prefix + name] = record_parameter
|
||||
|
||||
# save all buffers
|
||||
for name, buf in self.named_buffers():
|
||||
|
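Each saved parameter is recovered from the gathered flat chunk by slicing with its recorded offset/end and reshaping, as record_tensor does above. A tiny single-process illustration of that slice-and-view step; the layout, names, and offsets are invented for the example:

import torch

# invented layout: two parameters packed back-to-back in one flat chunk payload
temp_chunk = torch.arange(12, dtype=torch.float32)                  # stands in for the gathered chunk
infos = {'fc.weight': (0, 6, (2, 3)), 'fc.bias': (6, 9, (3,))}      # name -> (offset, end, shape)

state = {name: temp_chunk[offset:end].view(shape).cpu()
         for name, (offset, end, shape) in infos.items()}
print(state['fc.weight'].shape, state['fc.bias'].shape)             # torch.Size([2, 3]) torch.Size([3])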
@@ -466,40 +475,61 @@ class ZeroDDP(ColoDDP):
|
|||
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
def load(name, dest_tensor, copy_func):
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
def load(param_name, dest_tensor, copy_func):
|
||||
state_key = prefix + param_name
|
||||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
if input_param.shape != dest_tensor.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(key, input_param.shape,
|
||||
'the shape in current model is {}.'.format(state_key, input_param.shape,
|
||||
dest_tensor.shape))
|
||||
return
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
|
||||
copy_func(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
|
||||
ex.args))
|
||||
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
|
||||
input_param.size(), ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
missing_keys.append(state_key)
|
||||
|
||||
def load_fp32_p(fp32_p, data):
|
||||
if fp32_p.storage().size() > 0:
|
||||
self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)
|
||||
def load_fp32_parameter(chunk_slice, data):
|
||||
chunk_slice.copy_(data.flatten())
|
||||
|
||||
fp32_to_name = dict()
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
load(name, fp32_p, partial(load_fp32_p, fp32_p))
|
||||
self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
fp32_to_name[fp32_p] = name
|
||||
|
||||
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
for chunk in chunk_list:
|
||||
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
|
||||
|
||||
for tensor, tensor_info in chunk.tensors_info.items():
|
||||
parameter_name = fp32_to_name[tensor]
|
||||
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
|
||||
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
|
||||
|
||||
if chunk.is_gathered:
|
||||
chunk.chunk_total.copy_(temp_chunk)
|
||||
elif chunk.cuda_shard is not None:
|
||||
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
else:
|
||||
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
|
||||
|
||||
del temp_chunk
|
||||
|
||||
for chunk_32 in chunk_list:
|
||||
chunk_16 = chunk_32.paired_chunk
|
||||
assert chunk_16 is not None
|
||||
chunk_16.optim_update()
|
||||
|
||||
for name, buf in persistent_buffers.items():
|
||||
if buf is not None:
|
||||
|
|
|
@@ -0,0 +1,20 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_temp_total_chunk_on_cuda(chunk: Chunk):
|
||||
if chunk.is_gathered:
|
||||
return chunk.chunk_total
|
||||
|
||||
if chunk.cuda_shard is not None:
|
||||
shard_temp = chunk.cuda_shard
|
||||
else:
|
||||
shard_temp = chunk.cpu_shard.to(get_current_device())
|
||||
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
|
||||
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
|
||||
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
|
||||
|
||||
return total_temp
|
|
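The helper above reassembles a full chunk by all-gathering the shards directly into views of one flat buffer produced by torch.chunk. A single-process sketch of why writing into those views fills the buffer in place; dist.all_gather is replaced here by a plain copy, and the sizes are illustrative:

import torch

world_size, chunk_size = 4, 8
total = torch.zeros(chunk_size)
shards = list(torch.chunk(total, chunks=world_size, dim=0))     # views into `total`, not copies
for rank, shard in enumerate(shards):
    shard.copy_(torch.full_like(shard, float(rank)))            # stands in for all_gather from each rank
print(total)                                                    # tensor([0., 0., 1., 1., 2., 2., 3., 3.])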
@@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook):
|
|||
|
||||
@contextmanager
|
||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||
old_training_phase = self._training_phase
|
||||
try:
|
||||
old_training_phase = self._training_phase
|
||||
self._training_phase = training_phase
|
||||
yield
|
||||
finally:
|
||||
|
|
|
@@ -2,17 +2,14 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from enum import Enum
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn import Parameter
|
||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||
from typing import Dict
|
||||
from typing import Dict, Tuple, Set
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
|
||||
from collections import defaultdict, abc as container_abcs
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from torch._six import inf
|
||||
from colossalai.gemini.chunk import Chunk, ChunkManager
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
|
@@ -33,8 +30,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
Args:
|
||||
optim (Optimizer): An Optimizer instance.
|
||||
module (ZeroDDP): A ``ZeroDDP`` instance.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
|
||||
Defaults to 0.0.
|
||||
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
|
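Continuing the construction sketch shown after the ZeroDDP constructor above (so model refers to the ZeroDDP instance built there), gpu_margin_mem_ratio is simply passed at optimizer construction; the 0.3 margin ratio and 2**14 scale below are illustrative values, not defaults:

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

optim = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optim, model,
                           gpu_margin_mem_ratio=0.3,   # only takes effect with the 'auto' placement policy
                           initial_scale=2**14)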
@@ -61,11 +58,19 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
assert isinstance(module, ZeroDDP)
|
||||
self.module = module
|
||||
self.gemini_manager = module.gemini_manager
|
||||
self.chunk_manager = self.gemini_manager.chunk_manager
|
||||
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
|
||||
for p, fp32_p in zip(module.parameters(), module.fp32_params):
|
||||
self.fp16_param_to_fp32_param[p] = fp32_p
|
||||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
|
||||
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
for p, fp32_p in zip(params_list, module.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
if chunk_16 not in self.chunk16_set:
|
||||
self.chunk16_set.add(chunk_16)
|
||||
|
||||
self.__init__optimizer()
|
||||
|
||||
# Grad scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
|
@@ -75,7 +80,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
|
@@ -90,16 +95,26 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
self._register_states = disposable(self._register_states_)
|
||||
|
||||
def _update_params_ptr(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if not self.module.chunk_manager.get_chunk(p).is_empty:
|
||||
p.data = self.fp16_param_to_fp32_param[p]
|
||||
else:
|
||||
assert p.grad is None
|
||||
def _set_grad_ptr(self):
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
|
||||
fake_param.data = chunk16.payload[begin:end]
|
||||
fake_param.grad = fake_param.data
|
||||
fake_param.data = chunk32.payload[begin:end]
|
||||
|
||||
def _update_fp16_params(self):
|
||||
self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
|
||||
none_tensor = torch.empty([0])
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
assert fake_param.grad is None
|
||||
fake_param.data = none_tensor
|
||||
|
||||
for chunk16 in self.chunk16_set:
|
||||
chunk16.optim_update()
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
|
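The fake-parameter trick in _set_grad_ptr works because .data is re-pointed at a slice of the fp32 chunk payload while .grad aliases the matching fp16 slice, so a plain optimizer step updates the chunk storage in place. A single-process sketch with invented buffers (SGD stands in for whatever wrapped optimizer is used):

import torch

fp32_payload = torch.zeros(8)                    # stands in for chunk32.payload
grad_payload = torch.ones(8)                     # stands in for chunk16.payload holding reduced grads

fake_param = torch.nn.Parameter(torch.empty(0))  # placeholder registered with the optimizer
begin, end = 2, 6                                # this parameter's slice inside the chunk
fake_param.data = fp32_payload[begin:end]        # .data now aliases the fp32 slice
fake_param.grad = grad_payload[begin:end]        # .grad aliases the grad slice

torch.optim.SGD([fake_param], lr=0.1).step()     # updates the aliased storage in place
print(fp32_payload)                              # slice [2:6] becomes -0.1, the rest stays 0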
@@ -128,6 +143,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
def step(self, *args, **kwargs):
|
||||
self._maybe_move_fp32_params()
|
||||
self._set_grad_ptr()
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
|
@@ -138,45 +154,14 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return
|
||||
self._update_params_ptr()
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
self._register_states()
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def compute_grad_norm(self, norm_type: float = 2.0) -> float:
|
||||
norm_type = float(norm_type)
|
||||
if not self.chunk_manager.enable_distributed_storage:
|
||||
return compute_grad_norm(self.module.parameters(), norm_type)
|
||||
|
||||
non_distributed_params = []
|
||||
distributed_params = []
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
non_distributed_params.append(p)
|
||||
else:
|
||||
distributed_params.append(p)
|
||||
non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
|
||||
distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
|
||||
device=get_current_device())
|
||||
if norm_type == inf:
|
||||
dist.all_reduce(distributed_norm_tensor,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=self.chunk_manager.process_group.dp_process_group())
|
||||
total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
|
||||
else:
|
||||
dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
|
||||
total_norm = non_distributed_norm + distributed_norm_tensor.item()
|
||||
total_norm = total_norm**(1 / norm_type)
|
||||
return total_norm
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
total_norm = self.compute_grad_norm(norm_type)
|
||||
_clip_grad_norm(self.module.parameters(), max_norm, total_norm)
|
||||
return total_norm
|
||||
raise NotImplementedError
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
|
@@ -197,24 +182,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
|
||||
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
||||
fp32_params_used_cuda_margin_mem = 0
|
||||
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
|
||||
self.chunk_manager.chunk_groups['fp32_param']):
|
||||
if fp32_param_chunk.is_empty:
|
||||
continue
|
||||
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
|
||||
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
|
||||
# stores grad now
|
||||
self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device())
|
||||
self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device())
|
||||
fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem
|
||||
for p in fp16_param_chunk.get_tensors():
|
||||
state = self.optim.state[p]
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
chunk16 = chunk32.paired_chunk
|
||||
|
||||
if chunk32.device_type == 'cuda':
|
||||
continue
|
||||
|
||||
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
||||
self.chunk_manager.move_chunk(chunk32, get_current_device())
|
||||
# stores grad now
|
||||
self.chunk_manager.move_chunk(chunk16, get_current_device())
|
||||
self.module.set_chunk_grad_device(chunk16, get_current_device())
|
||||
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
chunk32 = self.param_to_chunk32[fake_param]
|
||||
if chunk32.device_type == 'cuda':
|
||||
state = self.optim.state[fake_param]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(get_current_device())
|
||||
|
||||
self.module._setup_grads_ptr()
|
||||
|
||||
def _register_states_(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
|
@@ -223,110 +215,27 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
if isinstance(val, torch.Tensor):
|
||||
self.chunk_manager.add_extern_static_tensor(val)
|
||||
|
||||
def state_dict(self, only_rank_0: bool = True):
|
||||
r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
|
||||
This saves memory usage.
|
||||
def __init__optimizer(self):
|
||||
|
||||
It contains two entries:
|
||||
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
|
||||
param_info = local_chunk.tensors_info[local_param]
|
||||
begin = max(0, param_info.offset - local_chunk.shard_begin)
|
||||
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
|
||||
return begin, end
|
||||
|
||||
* state - a dict holding current optimization state. Its content
|
||||
differs between optimizer classes.
|
||||
* param_groups - a list containing all parameter groups where each
|
||||
parameter group is a dict
|
||||
"""
|
||||
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||
if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
|
||||
return
|
||||
optim_state_dict = super().state_dict()
|
||||
scaler_state_dict = self.grad_scaler.state_dict()
|
||||
optim_state_dict['scaler'] = scaler_state_dict
|
||||
if not self.chunk_manager.enable_distributed_storage:
|
||||
return optim_state_dict
|
||||
local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
|
||||
if not self.chunk_manager.process_group.has_cpu_groups:
|
||||
self.chunk_manager.process_group.set_cpu_groups()
|
||||
output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
|
||||
if only_rank_0:
|
||||
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
|
||||
dist.gather_object(local_state,
|
||||
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
|
||||
dst=dst_rank,
|
||||
group=self.chunk_manager.process_group.cpu_dp_process_group())
|
||||
if not is_rank_0:
|
||||
return
|
||||
else:
|
||||
dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
|
||||
for state in output:
|
||||
optim_state_dict['state'].update(state)
|
||||
return optim_state_dict
|
||||
for group in self.optim.param_groups:
|
||||
fake_params_list = list()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
r"""Loads the optimizer state.
|
||||
for param in group['params']:
|
||||
chunk16 = self.chunk_manager.get_chunk(param)
|
||||
range_pair = get_range_pair(chunk16, param)
|
||||
if range_pair[0] >= range_pair[1]:
|
||||
continue
|
||||
|
||||
Args:
|
||||
state_dict (dict): optimizer state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
if 'scaler' not in state_dict:
|
||||
self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
|
||||
else:
|
||||
self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))
|
||||
fake_param = torch.nn.Parameter(torch.empty([0]))
|
||||
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
|
||||
self.param_to_range[fake_param] = range_pair
|
||||
|
||||
# Validate the state_dict
|
||||
groups = self.param_groups
|
||||
saved_groups = deepcopy(state_dict['param_groups'])
|
||||
fake_params_list.append(fake_param)
|
||||
|
||||
if len(groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of "
|
||||
"parameter groups")
|
||||
param_lens = (len(g['params']) for g in groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Update the state
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in groups)))
|
||||
}
|
||||
|
||||
def cast(param, value):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
state = defaultdict(dict)
|
||||
for k, v in state_dict['state'].items():
|
||||
if k in id_map:
|
||||
param = self.fp16_param_to_fp32_param[id_map[k]]
|
||||
if param.storage().size() > 0:
|
||||
state[param] = cast(param, deepcopy(v))
|
||||
else:
|
||||
state[k] = deepcopy(v)
|
||||
|
||||
# Update parameter groups, setting their 'params' value
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
|
||||
self.__setstate__({'state': state, 'param_groups': param_groups})
|
||||
|
||||
|
||||
def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
|
||||
return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
|
||||
group['params'] = fake_params_list
|
||||
|
|
|
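get_range_pair above clips a parameter's [offset, end) interval inside the flat chunk to the part owned by the local shard. A small numeric sketch of that clipping, with made-up sizes and offsets:

# illustrative numbers: a chunk of 1024 elements sharded across 4 ranks,
# with this rank owning chunk indices [256, 512); a parameter occupies [200, 400)
shard_begin, shard_size = 256, 256
param_offset, param_end = 200, 400

begin = max(0, param_offset - shard_begin)        # 0
end = min(shard_size, param_end - shard_begin)    # 144
assert (begin, end) == (0, 144)                   # local slice of the parameter on this rank
# if begin >= end the parameter has no elements on this rank and is skipped,
# which is exactly the `range_pair[0] >= range_pair[1]` check above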
@@ -6,11 +6,11 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.gemini import ChunkManager
|
||||
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from functools import partial
|
||||
from colossalai.nn.parallel import ColoDDP, ZeroDDP
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Callable
|
||||
from typing import Callable, Type
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
import random
|
||||
|
@@ -32,10 +32,9 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
|||
return ColoDDP(module, process_group=pg)
|
||||
|
||||
|
||||
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
|
||||
pg = ProcessGroup()
|
||||
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size, pg)
|
||||
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
|
||||
chunk_config = search_chunk_configuration(module, 4, 1024)
|
||||
chunk_manager = ChunkManager(chunk_config)
|
||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||
return ZeroDDP(module, gemini_manager)
|
||||
|
||||
|
@@ -51,7 +50,7 @@ class Net(torch.nn.Module):
|
|||
return self.fc2(self.fc1(x))
|
||||
|
||||
|
||||
def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
|
||||
def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = Net().cuda()
|
||||
w1 = model.fc1.weight
|
||||
|
@@ -62,8 +61,14 @@ def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], Col
|
|||
logits = model(x)
|
||||
loss = torch.sum(logits)
|
||||
model.backward(loss)
|
||||
|
||||
if ddp_cls is ZeroDDP:
|
||||
w1s_grad = w1
|
||||
else:
|
||||
w1s_grad = w1.grad
|
||||
|
||||
w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(w1_grads, w1.grad)
|
||||
dist.all_gather(w1_grads, w1s_grad)
|
||||
assert torch.equal(w1_grads[0], w1_grads[1])
|
||||
w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(w2_grads, w2.grad)
|
||||
|
@@ -74,8 +79,7 @@ def run_dist(rank, world_size, port):
|
|||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
set_seed(dist.get_rank())
|
||||
run_fwd_bwd(ColoDDP, init_ddp)
|
||||
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
|
||||
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))
|
||||
run_fwd_bwd(ZeroDDP, init_ddpv2)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@@ -8,14 +8,11 @@ from colossalai.testing import rerun_if_address_is_in_use
|
|||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.gemini import ChunkManager
|
||||
from functools import partial
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.nn.parallel import ColoDDP
|
||||
from collections import OrderedDict
|
||||
from colossalai.tensor import ProcessGroup, ColoParameter
|
||||
from colossalai.testing import parameterize
|
||||
|
||||
|
||||
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
|
||||
|
@@ -30,42 +27,11 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
|
|||
assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
|
||||
|
||||
|
||||
def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True):
|
||||
for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
|
||||
assert na == nb
|
||||
|
||||
if not allow_empty:
|
||||
assert pa.storage().size() > 0
|
||||
assert pb.storage().size() > 0
|
||||
else:
|
||||
if pa.storage().size() == 0 or pb.storage().size() == 0:
|
||||
continue
|
||||
|
||||
if same_dtype:
|
||||
assert pa.dtype == pb.dtype
|
||||
temp_pb = pb
|
||||
else:
|
||||
temp_pb = pb.to(pa.dtype)
|
||||
|
||||
assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb)
|
||||
|
||||
|
||||
def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
||||
pg = ProcessGroup()
|
||||
return ColoDDP(module, process_group=pg)
|
||||
|
||||
|
||||
def init_ddpv2(module: torch.nn.Module,
|
||||
use_chunk: bool = False,
|
||||
use_zero: bool = False,
|
||||
placement_policy: str = 'cuda') -> ZeroDDP:
|
||||
pg = ProcessGroup()
|
||||
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
return ZeroDDP(module, gemini_manager)
|
||||
|
||||
|
||||
def run_ddp_state_dict():
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
@@ -88,44 +54,9 @@ def run_ddp_state_dict():
|
|||
check_state_dict_equal(torch_state_dict, state_dict)
|
||||
|
||||
|
||||
@parameterize('use_chunk', [False, True])
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('use_zero', [False, True])
|
||||
@parameterize('only_rank_0', [False, True])
|
||||
def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0):
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
torch_model = model_builder().cuda()
|
||||
org_torch_model = copy.deepcopy(torch_model)
|
||||
torch_state_dict = torch_model.state_dict()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = init_ddpv2(model, use_chunk, use_zero, placement_policy)
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
assert param.get_process_group() is not None
|
||||
|
||||
model.load_state_dict(torch_state_dict, strict=False)
|
||||
check_model_equal(model, torch_model, allow_empty=True, same_dtype=False)
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, ColoParameter):
|
||||
assert param.get_process_group() is not None
|
||||
|
||||
pg = ProcessGroup()
|
||||
state_dict = model.state_dict(only_rank_0=only_rank_0)
|
||||
if not only_rank_0 or pg.dp_local_rank() == 0:
|
||||
torch_model.load_state_dict(state_dict, strict=False)
|
||||
check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_ddp_state_dict()
|
||||
run_zero_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
|
@@ -1,74 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
|
||||
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_stateful_tensor_container():
|
||||
st1 = StatefulTensor(torch.randn(1, device='cuda'))
|
||||
st2 = StatefulTensor(torch.randn(2, device='cuda'))
|
||||
st3 = StatefulTensor(torch.randn(3, device='cuda'))
|
||||
stateful_tensor_list = [st1, st2, st3]
|
||||
step_list = [st1, st2, st3, st3, st2, st1]
|
||||
|
||||
compute_step_dict = dict()
|
||||
compute_step_dict[st1] = [0, 5]
|
||||
compute_step_dict[st2] = [1, 4]
|
||||
compute_step_dict[st3] = [2, 3]
|
||||
|
||||
def run_queue_test():
|
||||
# test queue container
|
||||
queue_container = QueueSTContainer(compute_step_dict, 6)
|
||||
queue_container.create(stateful_tensor_list)
|
||||
|
||||
res_list = []
|
||||
|
||||
for i in range(6):
|
||||
stateful_tensor = step_list[i]
|
||||
stateful_tensor.trans_state(TensorState.COMPUTE)
|
||||
st_out = queue_container.pop()
|
||||
st_out.move_to(torch.device('cpu'))
|
||||
|
||||
res_list.append(st_out.payload.size(0))
|
||||
|
||||
stateful_tensor.move_to(torch.device('cuda'))
|
||||
queue_container.push(stateful_tensor, i)
|
||||
stateful_tensor.trans_state(TensorState.HOLD)
|
||||
|
||||
assert res_list == [2, 3, 1, 2, 3, 2]
|
||||
|
||||
run_queue_test()
|
||||
|
||||
def run_heap_test():
|
||||
# test heap container
|
||||
st1.move_to(torch.device('cuda'))
|
||||
st2.move_to(torch.device('cuda'))
|
||||
st3.move_to(torch.device('cuda'))
|
||||
|
||||
heap_container = HeapSTContainer(compute_step_dict, 6)
|
||||
heap_container.create(stateful_tensor_list)
|
||||
|
||||
res_list = []
|
||||
|
||||
for i in range(6):
|
||||
stateful_tensor = step_list[i]
|
||||
stateful_tensor.trans_state(TensorState.COMPUTE)
|
||||
st_out = heap_container.pop()
|
||||
|
||||
if st_out is not None:
|
||||
res_list.append(st_out.payload.size(0))
|
||||
st_out.move_to(torch.device('cpu'))
|
||||
|
||||
stateful_tensor.move_to(torch.device('cuda'))
|
||||
heap_container.push(stateful_tensor, i)
|
||||
stateful_tensor.trans_state(TensorState.HOLD)
|
||||
|
||||
assert res_list == [3, 1, 2, 3, 2]
|
||||
|
||||
run_heap_test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_stateful_tensor_container()
|
|
@@ -3,7 +3,7 @@ import colossalai
|
|||
import pytest
|
||||
import torch.multiprocessing as mp
|
||||
from functools import partial
|
||||
from colossalai.gemini.update import ChunkManagerV2
|
||||
from colossalai.gemini.chunk import ChunkManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
|
||||
|
@@ -19,23 +19,17 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
|
|||
def exam_chunk_memory(keep_gathered, pin_memory):
|
||||
pg = ProcessGroup()
|
||||
|
||||
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(
|
||||
keep_gathered, pin_memory))
|
||||
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
|
||||
|
||||
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
|
||||
config = {
|
||||
2: dict(
|
||||
chunk_size=128,
|
||||
keep_gathered=keep_gathered
|
||||
)
|
||||
}
|
||||
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
|
||||
|
||||
chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory)
|
||||
chunk_manager = ChunkManager(config)
|
||||
assert chunk_manager.total_mem['cpu'] == 0
|
||||
assert chunk_manager.total_mem['cuda'] == 0
|
||||
|
||||
for p in params:
|
||||
chunk_manager.append_tensor(p, 'param', 2)
|
||||
chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
|
||||
chunk_manager.close_all_groups()
|
||||
assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
|
||||
assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
|
||||
|
|
|
@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
|
|||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.update import ChunkV2
|
||||
from colossalai.gemini.chunk import Chunk
|
||||
|
||||
|
||||
def dist_sum(x):
|
||||
|
@@ -38,14 +38,12 @@ def check_euqal(param, param_cp):
|
|||
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ColoProcessGroup()
|
||||
my_chunk = ChunkV2(
|
||||
chunk_size=1024,
|
||||
process_group=pg,
|
||||
dtype=torch.float32,
|
||||
init_device=init_device,
|
||||
keep_gathered=keep_gathered,
|
||||
pin_memory=pin_memory
|
||||
)
|
||||
my_chunk = Chunk(chunk_size=1024,
|
||||
process_group=pg,
|
||||
dtype=torch.float32,
|
||||
init_device=init_device,
|
||||
keep_gathered=keep_gathered,
|
||||
pin_memory=pin_memory)
|
||||
|
||||
param_list = []
|
||||
param_cp_list = []
|
||||
|
|
|
@@ -0,0 +1,109 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
||||
from time import time
|
||||
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
|
||||
|
||||
|
||||
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
chunk_manager = model.chunk_manager
|
||||
param_list = [p for p in model.parameters()]
|
||||
chunk_list = chunk_manager.get_chunks(param_list)
|
||||
for chunk in chunk_list:
|
||||
chunk_manager.access_chunk(chunk)
|
||||
|
||||
for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
|
||||
assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
return logits
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
def exam_gpt_fwd_bwd(placement_policy):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
pg = ProcessGroup()
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(pg.dp_local_rank())
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 0:
|
||||
break
|
||||
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
model.backward(loss)
|
||||
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
|
||||
torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
|
||||
|
||||
check_grad(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_gpt_fwd_bwd()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(1)
|
|
@@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from tests.test_tensor.common_utils import debug_print
|
||||
|
||||
from time import time
|
||||
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
|
||||
|
||||
|
||||
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
|
||||
zero_dict = model.state_dict(only_rank_0=False)
|
||||
torch_dict = torch_model.state_dict()
|
||||
|
||||
for key, value in torch_dict.items():
|
||||
# key is 'module.model.PARAMETER', so we truncate it
|
||||
key = key[7:]
|
||||
if key == 'model.lm_head.weight':
|
||||
continue
|
||||
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
|
||||
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
|
||||
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
|
||||
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
|
||||
|
||||
|
||||
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
|
||||
optimizer.zero_grad()
|
||||
logits = model(input_ids, attn_mask)
|
||||
logits = logits.float()
|
||||
loss = criterion(logits, input_ids)
|
||||
optimizer.backward(loss)
|
||||
return logits
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
def exam_gpt_fwd_bwd(placement_policy):
|
||||
set_seed(42)
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
|
||||
torch_model = model_builder().cuda()
|
||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
config_dict[world_size]['chunk_size'] = 5000
|
||||
config_dict[world_size]['keep_gathered'] = False
|
||||
if placement_policy != 'cuda':
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
|
||||
|
||||
model.eval()
|
||||
torch_model.eval()
|
||||
|
||||
set_seed(dist.get_rank() * 3 + 128)
|
||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
|
||||
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
|
||||
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
|
||||
# debug_print([0], zero_logits, torch_logits)
|
||||
|
||||
zero_optim.step()
|
||||
torch_optim.step()
|
||||
|
||||
check_param(model, torch_model)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
exam_gpt_fwd_bwd()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gpt(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt(1)
|
|
@@ -8,7 +8,7 @@ import torch.distributed as dist
|
|||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.gemini.update import search_chunk_configuration
|
||||
from colossalai.gemini.chunk import search_chunk_configuration
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
|
||||
|
@@ -35,12 +35,11 @@ def exam_search_chunk_size():
|
|||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
init_1d_row_spec(model, pg_tp)
|
||||
config_dict = search_chunk_configuration(
|
||||
model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
filter_exlarge_params=True)
|
||||
config_dict = search_chunk_configuration(model,
|
||||
search_range_mb=1,
|
||||
search_interval_byte=16,
|
||||
min_chunk_size_mb=0,
|
||||
filter_exlarge_params=True)
|
||||
|
||||
for key in config_dict:
|
||||
chunk_size = config_dict[key]['chunk_size']
|
||||
|
|
|
@@ -0,0 +1,111 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext

from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print

from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    torch_model = model_builder()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    torch_dict = torch_model.state_dict()
    model.load_state_dict(torch_dict, strict=False)
    zero_dict = model.state_dict(only_rank_0=False)

    for key, value in torch_dict.items():
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()
    exam_load_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_ddp(1)
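For orientation, the config_dict that these tests mutate is the mapping returned by search_chunk_configuration: it is keyed by data-parallel group size, and each entry holds the chunk settings that ChunkManager consumes. The literal below is illustrative only; in the tests the real values come from the search and are then overridden.

# Assumed shape of the configuration dict (values here are for illustration):
config_dict = {
    4: {                         # data-parallel world size used as the key
        'chunk_size': 5000,      # number of elements per chunk
        'keep_gathered': False   # True keeps the chunk replicated on every rank instead of sharded
    }
}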
@@ -0,0 +1,97 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext

from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print

from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager


@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
    set_seed(431)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    set_seed(451)
    torch_model = model_builder()    # get a different model

    world_size = torch.distributed.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gathered

    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters())
    optim = ZeroOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32

    set_seed(dist.get_rank() * 3 + 128)
    model.train()
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 0:
            break
        optim.zero_grad()
        logits = model(input_ids, attn_mask)
        logits = logits.float()
        loss = criterion(logits, input_ids)
        optim.backward(loss)
        optim.step()

    optim_state_dict = optim.state_dict()
    optim.load_state_dict(optim_state_dict)
    new_state = optim.state_dict()['state']
    org_state = optim_state_dict['state']

    for k, v in org_state.items():
        w = new_state[k]
        for n, m in v.items():
            if isinstance(m, torch.Tensor):
                o = w[n]
                if m.device != o.device:
                    o = o.to(m.device)
                assert torch.equal(m, o)
            else:
                assert m == w[n]


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_zero_optim_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_optim(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_optim(1)
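set_seed is imported from tests.test_tensor.common_utils in the tests above; its body is not part of this diff. A minimal sketch of what such a helper typically looks like (an assumption, not the actual implementation):

import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    # Seed every RNG the GPT tests touch so each rank's data stream is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)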
@@ -1,86 +0,0 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from typing import List
from functools import partial
from colossalai.gemini import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup as ColoProcessGroup


def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
    for p, has_tensor in zip(params, has_tensors):
        if has_tensor:
            assert p.storage().size() > 0
            assert p.device.type == 'cuda'
        else:
            assert p.storage().size() == 0


# HAS_TENSORS[use_chunk][use_zero]
HAS_TENSORS = {
    True: {
        True: [[True, True, False], [False, False, True]],
        False: [[True, True, True], [True, True, True]]
    },
    False: {
        True: [[True, False, True], [False, True, False]],
        False: [[True, True, True], [True, True, True]]
    }
}

TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}


@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
    pg = ColoProcessGroup()
    rank = pg.rank()
    if rank == 0:
        print(f'use_chunk={use_chunk}, use_zero={use_zero}')
    params = [torch.rand(8, 8) for _ in range(3)]
    chunk_size = 128 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
    chunk_manager.create_group('param')
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == 0
    for p in params:
        chunk_manager.append_tensor(p, 'param')
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
    chunks = chunk_manager.get_chunks(params)
    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    check_has_params(params, [True, True, True])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
    for chunk in chunks:
        chunk_manager.release_chunk(chunk)
    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
    assert chunk_manager.total_mem['cpu'] == 0
    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
    for chunk in chunks:
        chunk_manager.move_chunk(chunk, torch.device('cpu'))
    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
    assert chunk_manager.total_mem['cuda'] == 0


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_chunk_zero()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_mapping(2)
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs

@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec
from tests.test_tensor.model.test_gpt2 import init_megatron_spec


def check_param_equal(model, torch_model, pg: ProcessGroup):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        if p.storage().size() > 0:
            assert p.dtype == torch.float16
            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'


def check_grad_equal(model, torch_model, pg: ProcessGroup):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        if p.grad is not None:
            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                      pg.tp_local_rank(), pg.tp_world_size()), \
                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'


def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        if key == 'model.lm_head.weight':
            continue
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
            "parameter '{}' has problem.".format(key)


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):

@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
    p.set_tensor_spec(*spec)


@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
def run_gpt(placement_policy, tp_init_spec_func=None):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    dp_world_size = pg.dp_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[dp_world_size]['chunk_size'] = 5000
    config_dict[dp_world_size]['keep_gathered'] = False
    if placement_policy != 'cuda':
        init_device = torch.device('cpu')
    else:
        init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=1)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)

@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    print(chunk_manager)
    check_param_equal(model, torch_model, pg)
    check_param(model, torch_model, pg)

    model.eval()
    torch_model.eval()

@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
        if i > 2:
            break
        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_grad_equal(model, torch_model, pg)
        optim.step()
        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)

        zero_optim.step()
        torch_optim.step()
        check_param_equal(model, torch_model, pg)
        check_param(model, torch_model, pg)


def run_dist(rank, world_size, port):
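tensor_shard_equal also comes from tests.test_tensor.common_utils and is not shown in this diff. A hypothetical sketch of its contract, inferred only from the call sites above: compare a full tensor against the shard held by the current tensor-parallel rank, slicing along whichever dimension the 1D layout split (tolerances are assumptions).

import torch


def tensor_shard_equal(full: torch.Tensor, shard: torch.Tensor, tp_rank: int, tp_world_size: int) -> bool:
    # If the tensor was not sharded by tensor parallelism, compare it directly.
    if full.shape == shard.shape:
        return torch.allclose(full, shard, rtol=1e-3, atol=1e-1)
    # Otherwise find the dimension that was split and compare the slice owned by this rank.
    for dim, (f, s) in enumerate(zip(full.shape, shard.shape)):
        if f == s * tp_world_size:
            return torch.allclose(full.chunk(tp_world_size, dim=dim)[tp_rank], shard,
                                  rtol=1e-3, atol=1e-1)
    raise ValueError('shapes are not compatible with any 1D shard layout')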
@@ -1,100 +0,0 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ProcessGroup


def check_state(s1, s2):
    for v1, v2 in zip(s1.values(), s2.values()):
        if isinstance(v1, torch.Tensor):
            v1 = v1.to(v2.device)
            assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
        else:
            assert v1 == v2


def check_load_state_dict(optim, torch_optim):
    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
        for p, torch_p in zip(group['params'], torch_group['params']):
            state = optim.optim.state[p]
            torch_state = torch_optim.state[torch_p]
            if p.storage().size() == 0:
                assert len(state) == 0
            check_state(state, torch_state)


def check_state_dict(state_dict, torch_state_dict):
    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
        assert k1 == k2
        check_state(s1, s2)


@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('only_rank_0', [False, True])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    pg = ProcessGroup()

    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=1)

    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)

    for p in torch_model.parameters():
        p.grad = torch.rand_like(p)

    torch_optim.step()
    torch_state_dict = torch_optim.state_dict()
    optim.load_state_dict(torch_state_dict)
    check_load_state_dict(optim, torch_optim)

    state_dict = optim.state_dict(only_rank_0)
    if not only_rank_0 or pg.rank() == 0:
        check_state_dict(state_dict, torch_state_dict)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_zero_optim_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_optim_state_dict(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_optim_state_dict(2)