[feature] A new ZeRO implementation (#1644)

HELSON 2022-10-09 09:18:51 +08:00 committed by GitHub
parent b1be5b88bd
commit b28991dd0a
29 changed files with 935 additions and 1537 deletions

View File

@ -1,10 +1,6 @@
from .chunk import TensorInfo, Chunk, TensorState
from .chunk_mgr import ChunkManager
from .chunk import TensorInfo, TensorState
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
from .gemini_mgr import GeminiManager
__all__ = [
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk',
'TensorState'
]
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState']

View File

@ -1,316 +0,0 @@
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
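The tuple above is the entire state machine: a transition is legal only if the (current, next) pair appears in it, and tensor_trans_state below silently ignores anything else. A minimal illustration of how the table is consulted (not part of the diff):
def try_transition(current: TensorState, requested: TensorState) -> TensorState:
    # invalid requests are ignored, mirroring tensor_trans_state below
    if (current, requested) not in STATE_TRANS:
        return current
    return requested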
@dataclass
class TensorInfo:
state: TensorState
offset: int
end: int
class ChunkFullError(Exception):
pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
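These helpers rely on torch.Storage.resize_: shrinking the storage to zero drops the backing buffer while the tensor object keeps its shape and dtype, so the payload can be re-materialized later. A small CPU-only illustration (not part of the diff):
import torch

t = torch.empty(1024, dtype=torch.float16)
free_storage(t)                              # backing buffer released, metadata kept
assert is_storage_empty(t) and t.numel() == 1024
alloc_storage(t)                             # buffer re-allocated, contents undefined
assert t.storage().size() == t.numel()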
class Chunk:
"""
A chunk is a contiguous memory space which contains multiple tensors.
Args:
chunk_size (int): the number of elements in a chunk
src_rank (int): the process which owns the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
"""
def __init__(self,
chunk_size: int,
src_rank: int,
process_group: ColoProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
force_data_on_cuda: bool = False) -> None:
self.size = chunk_size
self.utilized_size = 0
self.src_rank = src_rank
self.process_group = process_group
self.is_src_rank = process_group.dp_local_rank() == src_rank
self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
self.dtype = dtype
device = init_device or get_current_device()
if force_data_on_cuda:
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
if device.type == 'cuda':
free_storage(self._cpu_data)
else:
free_storage(self.data)
else:
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
self._cpu_data = None
# we only keep the chunk in full in the process by which the tensor is owned
if not self.is_src_rank:
free_storage(self._payload)
# each tensor is associated with a TensorInfo to track meta info
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
self.mem = self.size * self.data.element_size()
def append(self, tensor: torch.Tensor) -> None:
"""
Add a tensor to the chunk.
Args:
tensor (torch.Tensor): a tensor to be added to the chunk
"""
assert tensor.dtype == self.dtype
new_utilized_size = self.utilized_size + tensor.numel()
# raise exception when the chunk size is exceeded
if new_utilized_size > self.size:
raise ChunkFullError
# set tensor state
tensor_state = TensorState.FREE
# if the process owns the rank, then copy the tensor to its chunk buffer
# otherwise set its storage size to 0 to reduce memory consumption
if self.is_src_rank:
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
tensor_state = TensorState.HOLD
assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
else:
tensor.storage().resize_(0)
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
self.utilized_size = new_utilized_size
def release(self) -> None:
"""
Release the memory space on processes which do not own the chunk.
"""
if not self.is_src_rank:
free_storage(self._payload)
self._update_tensors_state(TensorState.FREE)
def _update_tensors_ptr(self) -> None:
assert type(self._payload) == torch.Tensor
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
if prev_state is None or tensor_info.state == prev_state:
tensor_info.state = next_state
def access(self) -> None:
"""
Broadcast the chunk to synchronize the tensors across data parallel processes.
"""
# recover the chunk on non-owner processes
# and broadcast the chunk from the source to all processes
if not self.is_src_rank:
alloc_storage(self._payload)
self.move_device(get_current_device(), update_ptr=False)
dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
# update tensor meta info
self._update_tensors_ptr()
if not self.is_src_rank:
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
"""
Move the chunk to a target device.
Args:
device (torch.device): the target device for data movement.
"""
if self._payload.device == device:
return
if self._cpu_data is None:
self.data.data = self.data.to(device)
else:
if device.type == 'cuda':
# cpu -> cuda
src = self._cpu_data
dest = self.data
else:
# cuda -> cpu
src = self.data
dest = self._cpu_data
alloc_storage(dest)
dest.copy_(src)
free_storage(src)
if update_ptr:
self._update_tensors_ptr()
def reduce(self, is_all_reduce: bool = False) -> None:
"""
Reduce or all-reduce the chunk.
Args:
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
"""
self.move_device(get_current_device(), update_ptr=False)
if is_all_reduce:
dist.all_reduce(self.data, group=self.process_group.dp_process_group())
else:
dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
self._update_tensors_ptr()
self._update_tensors_state(TensorState.HOLD)
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
"""
Make a transition of the tensor into the next state.
Args:
tensor (torch.Tensor): a torch Tensor object.
tensor_state (TensorState): the target state for transition.
"""
# As the gradient hook can be triggered either before or after post-backward
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
# or compute -> ready_for_reduce -> hold_after_bwd
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
# this function only applies valid state transformations
# invalid calls will be ignored and nothing changes
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
# print(
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# )
return
self.tensors_info[tensor].state = tensor_state
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrieve meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
tensor_info = self.tensors_info[tensor]
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
@property
def can_release(self) -> bool:
"""
Check whether the chunk can be released.
"""
for tensor_info in self.tensors_info.values():
if tensor_info.state != TensorState.HOLD:
return False
return True
@property
def can_move_device(self) -> bool:
"""
Check whether the chunk can be moved across devices.
"""
for tensor_info in self.tensors_info.values():
if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
return False
return True
@property
def can_reduce(self) -> bool:
"""
Check whether the chunk can be reduced.
"""
for tensor_info in self.tensors_info.values():
if tensor_info.state != TensorState.READY_FOR_REDUCE:
return False
return True
@property
def is_empty(self) -> bool:
"""
Check whether the chunk is empty.
"""
return is_storage_empty(self._payload)
def __repr__(self) -> str:
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
@property
def has_inf_or_nan(self) -> bool:
"""
Check if the chunk has inf or nan values.
"""
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
torch.isnan(self._payload[:self.utilized_size]).any().item()
def copy_(self, dest_chunk: 'Chunk'):
"""
Copy the data of this chunk to a destination chunk.
"""
assert not self.is_empty
assert not dest_chunk.is_empty
assert self.size == dest_chunk.size
assert self.utilized_size == dest_chunk.utilized_size
self._payload.copy_(dest_chunk._payload)
self._update_tensors_ptr()
@property
def device_type(self) -> str:
"""
Get the device type of the chunk.
"""
return self._payload.device.type
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, __o: object) -> bool:
return self is __o
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
@property
def _payload(self) -> torch.Tensor:
if self._cpu_data is None or is_storage_empty(self._cpu_data):
return self.data
return self._cpu_data

View File

@ -0,0 +1,3 @@
from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
from .manager import ChunkManager
from .search_utils import clasify_params, search_chunk_configuration

View File

@ -1,14 +1,55 @@
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
free_storage, alloc_storage
class ChunkV2:
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
TensorState.HOLD))
@dataclass
class TensorInfo:
state: TensorState
offset: int
end: int
class ChunkFullError(Exception):
pass
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0
def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())
class Chunk:
def __init__(self,
chunk_size: int,
@ -19,18 +60,18 @@ class ChunkV2:
pin_memory: bool = False) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
AgChunk is a kind of chunk, which uses all-gather operation to gather the whole chunk.
This kind of chunk is exclusively used for DDP and ZeRO DDP.
Here we use all-gather operation to gather the whole chunk.
Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
It is designed to make the full use of communication and PCIE bandwidth.
Args:
chunk_size (int): the number of elements in a chunk
chunk_size (int): the number of elements in the chunk
process_group (ColoProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized
The default value is None, which is the current GPU
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
self.chunk_size = chunk_size
@ -42,7 +83,8 @@ class ChunkV2:
self.pg_rank = dist.get_rank(self.torch_pg)
# the chunk size should be divisible by the size of the process group
assert chunk_size % self.pg_size == 0
if not keep_gathered:
assert chunk_size % self.pg_size == 0
self.shard_size = chunk_size // self.pg_size
self.shard_begin = self.shard_size * self.pg_rank
self.shard_end = self.shard_begin + self.shard_size
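Unlike the broadcast-based chunk it replaces, the new Chunk keeps one equal-sized shard per rank and all-gathers the rest on access. A worked example of the shard arithmetic above (values are illustrative):
chunk_size = 1_048_576                       # elements in the whole chunk
pg_size = 4                                  # data-parallel world size (chunk_size % pg_size == 0)
pg_rank = 2                                  # this process's rank in the group
shard_size = chunk_size // pg_size           # 262144 elements kept locally
shard_begin = shard_size * pg_rank           # 524288
shard_end = shard_begin + shard_size         # 786432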
@ -80,18 +122,15 @@ class ChunkV2:
# we introduce the paired chunk here
# it refers to another chunk having the same parameters
# but with a different dtype (such as fp16_chunk.mapping_chunk -> fp32_chunk)
# but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
self.paired_chunk = None
# if the the gradient of this chunk is reduced, the flag is True
# so the flag is False for unused parameters
self.grad_reduced_flag = False
# if this chunk is synchronized with the optimizer, the flag is True
self.optim_sync_flag = True
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False
@property
def memory_usage(self):
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
cpu_memory = 0
@ -112,7 +151,7 @@ class ChunkV2:
return dict(cuda=cuda_memory, cpu=cpu_memory)
@property
def device_type(self):
def device_type(self) -> str:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
@ -123,6 +162,56 @@ class ChunkV2:
else:
return 'cpu'
@property
def payload(self) -> torch.Tensor:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_total
elif self.cuda_shard is not None:
return self.cuda_shard
else:
return self.cpu_shard
@property
def payload_mem(self) -> int:
# sanity check
assert self.chunk_temp is None
if self.is_gathered:
return self.chunk_mem
else:
return self.shard_mem
@property
def can_move(self) -> bool:
return not self.is_gathered
@property
def can_release(self) -> bool:
if self.keep_gathered:
return False
else:
return self.tensors_state_monitor[TensorState.HOLD] + \
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values in CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check in CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.
@ -150,7 +239,10 @@ class ChunkV2:
self.utilized_size = new_utilized_size
def close_chunk(self, shard_dev: Optional[torch.device] = None):
"""Close the chunk. Any tensor can't be appended to a closed chunk.
"""Close the chunk. Any tensor can't be appended to a closed chunk later.
Args:
shard_dev: the device where the shard locates
"""
# sanity check
assert self.chunk_temp is not None
@ -163,6 +255,7 @@ class ChunkV2:
if self.chunk_temp.device.type == 'cpu':
self.chunk_total = self.chunk_temp.to(get_current_device())
self.__update_tensors_ptr()
else:
self.chunk_total = self.chunk_temp
self.chunk_temp = None
@ -186,6 +279,12 @@ class ChunkV2:
self.cuda_shard = None
def shard_move(self, device: torch.device, force_copy: bool = False):
"""Move the shard tensor in the chunk.
Args:
device: the device to which the shard will move
force_copy: if True, the copy is performed unconditionally
"""
# sanity check
assert not self.is_gathered
# when the current chunk is not synchronized with the optimizer
@ -223,8 +322,7 @@ class ChunkV2:
raise NotImplementedError
def access_chunk(self):
"""Make the chunk usable for the parameters inside it.
It is an operation done in CUDA.
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
@ -234,8 +332,7 @@ class ChunkV2:
self.__update_tensors_ptr()
def release_chunk(self):
"""Release the usable chunk.
It is an operation done in CUDA.
"""Release the usable chunk. It's an operation done in CUDA.
"""
# sanity check
assert self.chunk_temp is None
@ -244,8 +341,7 @@ class ChunkV2:
self.__scatter()
def reduce(self):
"""Reduce scatter all the gradients.
It is an operation done in CUDA.
"""Reduce scatter all the gradients. It's an operation done in CUDA.
"""
# sanity check
assert self.is_gathered
@ -267,7 +363,6 @@ class ChunkV2:
free_storage(self.chunk_total)
self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD)
self.grad_reduced_flag = True
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
"""
@ -285,9 +380,6 @@ class ChunkV2:
# this function only applies valid state transformations
# invalid calls will be ignored and nothing changes
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
# print(
# f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# )
return
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
@ -306,46 +398,58 @@ class ChunkV2:
self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten())
tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
@property
def can_move(self) -> bool:
return not self.is_gathered
@property
def can_release(self) -> bool:
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload.
"""
if self.keep_gathered:
return False
return self.utilized_size
else:
return self.tensors_state_monitor[TensorState.HOLD] + \
self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors
return self.valid_end
@property
def can_reduce(self):
return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
@property
def has_inf_or_nan(self) -> bool:
def init_pair(self, friend_chunk: 'Chunk') -> None:
"""Initialize the paired chunk.
"""
Check if the chunk has inf or nan values in CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
self.paired_chunk = friend_chunk
friend_chunk.paired_chunk = self
else:
assert self.cuda_shard is not None # only check in CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
assert self.paired_chunk is friend_chunk
assert friend_chunk.paired_chunk is self
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def optim_update(self) -> None:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
# sanity check
assert self.paired_chunk is not None
friend_chunk = self.paired_chunk
if self.is_gathered is True:
assert friend_chunk.is_gathered is True
self.chunk_total.copy_(friend_chunk.chunk_total)
self.optim_sync_flag = True
elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda':
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
else:
# optim_sync_flag is set to False
# see shard_move function for more details
assert friend_chunk.device_type == 'cpu'
assert self.device_type == 'cpu'
self.optim_sync_flag = False
self.cpu_vis_flag = False
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def __gather(self):
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None
if self.pg_size == 1:
self.chunk_total = self.cuda_shard
else:
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
alloc_storage(self.chunk_total)
gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
self.cuda_shard = None
self.is_gathered = True
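__gather writes each rank's shard directly into a view of the flat buffer, so no temporary concatenation is needed. A hedged sketch of the same pattern, assuming torch.distributed is already initialized (the helper name is illustrative):
import torch
import torch.distributed as dist

def gather_shards(chunk_total: torch.Tensor, cuda_shard: torch.Tensor, pg) -> None:
    world_size = dist.get_world_size(pg)
    # the views share chunk_total's storage, so all_gather fills the buffer in place
    gather_list = list(torch.chunk(chunk_total, chunks=world_size, dim=0))
    dist.all_gather(gather_list, cuda_shard, group=pg)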
@ -404,9 +508,9 @@ class ChunkV2:
def __eq__(self, __o: object) -> bool:
return self is __o
def __repr__(self, detailed: bool = False):
def __repr__(self, detailed: bool = True):
output = [
"AgChunk Information:\n",
"Chunk Information:\n",
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype,
self.pg_size),
"\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
@ -442,6 +546,3 @@ class ChunkV2:
output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st]))
return ''.join(output)
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())

View File

@ -4,23 +4,19 @@ from collections import deque
from colossalai.utils import get_current_device
from colossalai.tensor import ColoTensor
from colossalai.gemini.chunk import ChunkFullError, TensorState
from colossalai.gemini.update import ChunkV2 as Chunk
from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk
class ChunkManagerV2:
class ChunkManager:
"""
A manager class to manipulate the tensors in chunks.
Args:
chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
pin_memory (bool): if true, all chunks have a piece of pinned memory in CPU.
"""
def __init__(self, chunk_configuration: Dict[int, Dict],
init_device: Optional[torch.device] = None,
pin_memory: bool = False) -> None:
def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
self.device = init_device or get_current_device()
self.size_config: Dict[int, int] = dict()
@ -28,7 +24,6 @@ class ChunkManagerV2:
for k, v in self.kwargs_config.items():
self.size_config[k] = v.pop('chunk_size')
v['init_device'] = self.device
v['pin_memory'] = pin_memory
self.chunk_groups: Dict[str, Deque] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
@ -36,8 +31,14 @@ class ChunkManagerV2:
self.lazy_release_tensors: List[torch.Tensor] = list()
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None:
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
"""Append a tensor to a chunk.
Args:
tensor: the tensor appended to the chunk
group_type: the data type of the group
config_key: the key of the group's name, usually the data-parallel world size
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
@ -66,7 +67,8 @@ class ChunkManagerV2:
chunk_size=chunk_size,
process_group=tensor.process_group,
dtype=tensor.dtype,
**chunk_kwargs
pin_memory=pin_memory,
**chunk_kwargs,
)
chunk_group.append(chunk)
@ -87,6 +89,8 @@ class ChunkManagerV2:
if chunk in self.accessed_chunks:
return
self.__sub_memroy_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
chunk.shard_move(get_current_device())
chunk.access_chunk()
self.__add_memory_usage(chunk.memory_usage)
self.accessed_chunks.add(chunk)
@ -102,13 +106,13 @@ class ChunkManagerV2:
self.__add_memory_usage(chunk.memory_usage)
self.accessed_chunks.remove(chunk)
def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
"""
if not chunk.can_move or chunk.device_type == device.type:
return
self.__sub_memroy_usage(chunk.memory_usage)
chunk.shard_move(device)
chunk.shard_move(device, force_copy)
self.__add_memory_usage(chunk.memory_usage)
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
@ -123,7 +127,7 @@ class ChunkManagerV2:
if not chunk.can_reduce:
return False
self.__sub_memroy_usage(chunk.memory_usage)
chunk.release_chunk()
chunk.reduce()
self.__add_memory_usage(chunk.memory_usage)
return True
@ -165,14 +169,14 @@ class ChunkManagerV2:
self.release_chunk(chunk)
self.lazy_release_tensors.clear()
def __repr__(self) -> str:
msg = ['Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n']
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]:
chunk_list = []
for group_name in self.chunk_groups:
if group_type in group_name:
for chunk in self.chunk_groups[group_name]:
if chunk.device_type == 'cuda' and chunk.can_move:
chunk_list.append(chunk)
return chunk_list
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
"""
@ -200,6 +204,17 @@ class ChunkManagerV2:
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str:
msg = [
'Chunk Manager Information:\n',
'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
]
for group_name, group in self.chunk_groups.items():
msg.append(f'Group {group_name}:\n')
for i, chunk in enumerate(group):
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
def __get_chunk_group(self, group_name: str) -> Deque:
"""Register a chunk group.
"""
@ -208,8 +223,9 @@ class ChunkManagerV2:
return self.chunk_groups[group_name]
def __close_one_chunk(self, chunk: Chunk):
device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda
self.__sub_memroy_usage(chunk.memory_usage)
chunk.close_chunk(self.device)
chunk.close_chunk(device)
self.__add_memory_usage(chunk.memory_usage)
def __sub_memroy_usage(self, usage: Dict[str, int]):

View File

@ -1,3 +1,4 @@
import math
from typing import Dict, List
import numpy as np
import torch.nn as nn
@ -7,7 +8,7 @@ from colossalai.tensor import ColoParameter
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""Filter those parameters whose size is too large from others.
"""
params_size = [p.numel() for p in model.parameters()]
params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
params_size_arr = np.array(params_size)
std = np.std(params_size_arr)
@ -36,6 +37,9 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in model.parameters():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if getattr(param, '_ddp_to_ignore', False):
continue
param_key = param.process_group.dp_world_size()
if param_key not in params_dict:
@ -47,13 +51,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
def search_chunk_configuration(
model: nn.Module,
search_range_mb: int,
search_range_mb: float,
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: int = 32,
filter_exlarge_params: bool = True):
search_range_byte = search_range_mb * 1024**2
min_chunk_size_byte = min_chunk_size_mb * 1024**2
assert search_range_byte % search_interval_byte == 0
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True) -> Dict:
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = clasify_params(model)
config_dict: Dict[int, Dict] = dict()
@ -75,11 +79,12 @@ def search_chunk_configuration(
max_size = min_chunk_size_byte
for key in size_dict:
max_size = max(max_size, max(size_dict[key]))
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
min_chunk_waste = float('+inf')
best_chunk_size = max_size
best_chunk_size = start_size
for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte):
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
temp_waste = 0
for key in size_dict:
temp_waste += _get_unused_byte(size_dict[key], chunk_size)
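The loop above scans candidate chunk sizes and keeps the one with the least padding waste. _get_unused_byte itself is not shown in this hunk; an assumed equivalent based on greedy packing (illustrative only):
from typing import List

def estimate_unused_bytes(param_sizes: List[int], chunk_size: int) -> int:
    unused, filled = 0, 0
    for size in param_sizes:                  # each size is assumed to fit in one chunk
        if filled + size > chunk_size:        # parameter does not fit: close the chunk
            unused += chunk_size - filled     # its remaining space is wasted
            filled = 0
        filled += size
    return unused + (chunk_size - filled)     # tail padding of the last open chunk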

View File

@ -1,344 +0,0 @@
import torch
import numpy as np
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor
from .chunk import Chunk, ChunkFullError, TensorState
class ChunkManager:
"""
A manager class to manipulate the tensors in chunks.
Args:
chunk_size (int): the size of a chunk.
process_group (ColoProcessGroup): process group of the chunk.
enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false.
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""
def __init__(self,
chunk_size: Optional[int],
process_group: ColoProcessGroup,
enable_distributed_storage: bool = False,
init_device: Optional[torch.device] = None) -> None:
assert chunk_size is None or chunk_size > 0
assert isinstance(process_group, ColoProcessGroup)
self.chunk_size = chunk_size
self.process_group = process_group
self.enable_distributed_storage = enable_distributed_storage
self.device = init_device or get_current_device()
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
self.groups_force_data_on_cuda: Dict[str, bool] = {}
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
self.accessed_chunks: Set[Chunk] = set()
self.lazy_release_tensors: List[torch.Tensor] = []
if enable_distributed_storage and chunk_size is None:
self.rank_load: Dict[str, torch.Tensor] = {}
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
"""Create a chunk group.
Args:
group_name (str): group name
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda. Defaults to False.
"""
assert group_name not in self.chunk_groups
self.chunk_groups[group_name] = deque()
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda
def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
"""
Append a tensor to a chunk.
Args:
tensor (torch.Tensor): a tensor to append to the chunk.
group_name (str): the name of the chunk group.
"""
assert tensor not in self.tensor_chunk_map
if isinstance(tensor, ColoTensor):
assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group(
), f"Chunk Manager can only manage ColoTensor with the same DP process group"
try:
# append the tensor to the last chunk
self.chunk_groups[group_name][-1].append(tensor)
except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
chunk_size = tensor.numel()
else:
chunk_size = self.chunk_size or tensor.numel()
src_rank = self._get_next_src_rank(group_name)
chunk = Chunk(chunk_size,
src_rank,
self.process_group,
tensor.dtype,
self.device,
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
if self.enable_distributed_storage and self.chunk_size is None:
self.rank_load[group_name][src_rank] += chunk_size
self.chunk_groups[group_name].append(chunk)
chunk.append(tensor)
if not chunk.is_empty:
self.total_mem[chunk.device_type] += chunk.mem
self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
if not self.enable_distributed_storage:
# as distributed storage is not enabled, there is no need to broadcast
# chunks, thus we set these chunks as accessed
self.accessed_chunks.add(self.chunk_groups[group_name][-1])
def _get_next_src_rank(self, group_name: str) -> int:
if not self.enable_distributed_storage:
# the chunk is owned by the current rank if no distributed storage is enabled
return self.process_group.dp_local_rank()
if self.chunk_size is None:
if group_name not in self.rank_load:
self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
# the process owning the tensor will be the process with the smallest number of elements
src_rank = torch.argmin(self.rank_load[group_name]).item()
else:
# chunk is owned by processes in a round-robin fashion
chunk_idx = len(self.chunk_groups[group_name])
src_rank = chunk_idx % self.process_group.dp_world_size()
return src_rank
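The two ownership policies of the removed manager balance memory across ranks: with a fixed chunk size the owner simply rotates, otherwise the least-loaded rank takes the new chunk. A tiny illustration with made-up numbers:
import torch

rank_load = torch.tensor([3_000_000, 1_500_000, 2_200_000, 1_400_000])
argmin_owner = torch.argmin(rank_load).item()    # -> 3, the rank holding the fewest elements
chunk_idx, dp_world_size = 7, 4
round_robin_owner = chunk_idx % dp_world_size    # -> 3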
def access_chunk(self, chunk: Chunk) -> None:
"""
Synchronize the chunks via broadcast.
Args:
chunk (Chunk): the chunk to synchronize.
"""
if chunk in self.accessed_chunks:
if chunk.device_type != 'cuda':
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(get_current_device())
self.total_mem[chunk.device_type] += chunk.mem
return
if not chunk.is_empty:
# as tensor is moved to the target device
# the memory consumption of the original device is reduced
self.total_mem[chunk.device_type] -= chunk.mem
chunk.access()
self.accessed_chunks.add(chunk)
self.total_mem[chunk.device_type] += chunk.mem
def release_chunk(self, chunk: Chunk) -> None:
"""
Release the memory space of a chunk.
Args:
chunk (Chunk): the chunk to release memory space
"""
if not self.enable_distributed_storage:
return
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
chunk.release()
self.accessed_chunks.remove(chunk)
if chunk.is_empty:
# update the memory consumption after releasing
self.total_mem[chunk.device_type] -= chunk.mem
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
"""
Move the chunk to the target device.
Args:
chunk (Chunk): the chunk to move to target device
device (torch.device): target device
"""
if chunk.device_type == device.type:
return
if chunk.can_move_device and not chunk.is_empty:
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(device, update_ptr=update_ptr)
self.total_mem[chunk.device_type] += chunk.mem
def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
"""
Transit tensor state according to pre-defined state machine.
Args:
tensor (torch.Tensor): the tensor for state transition
state (TensorState): next tensor state for transition
"""
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool:
"""
Reduce or all reduce the chunk. If enable_distributed_storage is true, all-reduce is used.
Otherwise, this method uses reduce.
Args:
chunk (Chunk): the chunk for reduction.
"""
if not chunk.can_reduce:
return False
self.total_mem[chunk.device_type] -= chunk.mem
chunk.reduce(is_all_reduce=not self.enable_distributed_storage)
self.total_mem[chunk.device_type] += chunk.mem
return True
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
"""
Copy data to the chunk.
Args:
tensor (torch.Tensor): the tensor used to retrieve meta information
data (torch.Tensor): the tensor to be copied to the chunk
"""
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
"""
Return the chunk owning the tensor.
Args:
tensor (torch.Tensor): a torch tensor object
"""
return self.tensor_chunk_map[tensor]
def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
"""
Add tensors to the buffer for lazy release.
Args:
tensors (List[torch.Tensor]): the tensors to be released lazily
"""
self.lazy_release_tensors.extend(tensors)
def exec_lazy_release(self) -> None:
"""
Execute release for tensors added to the lazy release buffer.
"""
for chunk in self.get_chunks(self.lazy_release_tensors):
self.release_chunk(chunk)
self.lazy_release_tensors.clear()
def __repr__(self) -> str:
msg = f'Rank {self.process_group.dp_local_rank()}:\n'
msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
for group_name, group in self.chunk_groups.items():
msg += f'Group {group_name}:\n'
for i, chunk in enumerate(group):
msg += f'[{i}] {chunk}\n'
return msg
@staticmethod
def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
"""
Calculate the utilization rate of a chunk.
Args:
chunk_size (int): the size of a chunk
params_numel (List[int]): the list of integers representing the number of elements of parameters
"""
assert len(params_numel) > 0
total_size = 0
total_utilized_size = 0
cur_chunk_utilized_size = 0
for size in params_numel:
assert chunk_size >= size
total_utilized_size += size
if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
total_size += chunk_size
cur_chunk_utilized_size = 0
cur_chunk_utilized_size += size
return total_utilized_size / total_size
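A quick worked example of the utilization metric computed by the removed helper above: parameters of 700, 200, and 400 elements packed into 1000-element chunks fill the first chunk to 900 and spill 400 into a second one, so utilization is 1300 / 2000 = 0.65.
util = ChunkManager.get_chunk_util(chunk_size=1000, params_numel=[700, 200, 400])
assert abs(util - 0.65) < 1e-9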
@staticmethod
def search_chunk_size(module: torch.nn.Module,
search_range: int,
n_grids: int,
min_chunk_size: Optional[int] = None,
filter_exlarge_params: bool = True) -> int:
"""
Search for the chunk size for optimal chunk utilization.
Args:
module (torch.nn.Module): a torch module object
search_range (int): the range of chunk size to search. The actual search range will be from
max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range.
n_grids (int): the number of intervals in the search range
min_chunk_size (int): optional, the minimum size for a chunk. The default is None.
"""
assert search_range % n_grids == 0
# TODO(ver217): sort params and filter unused ones
params_numel = [p.numel() for p in module.parameters()]
if filter_exlarge_params:
params_numel = _filter_exlarge_params(params_numel)
max_param_numel = max(params_numel)
if min_chunk_size is not None:
assert min_chunk_size >= max_param_numel
else:
min_chunk_size = max_param_numel
step_size = search_range // n_grids
max_chunk_util = -1
best_chunk_size = -1
for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
if chunk_util > max_chunk_util:
max_chunk_util = chunk_util
best_chunk_size = chunk_size
return best_chunk_size
def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
"""
Copy chunk data from one group to another group.
Args:
dest_group_name (str): the destination group which receives the copied data
src_group_name (str): the source group which provides the data to copy
"""
for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
if not dest_chunk.is_empty:
dest_chunk.copy_(src_chunk)
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
"""
Get all chunks owning the input tensors.
Args:
tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
"""
chunks = []
for tensor in tensors:
chunk = self.get_chunk(tensor)
if chunk not in chunks:
chunks.append(chunk)
return tuple(chunks)
def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
"""Add extern static tensor to chunk manager.
Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
They are "static", which means their shape, dtype, device never change.
Thus, their memory usage never changes.
Args:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
def _filter_exlarge_params(params_numel: List[int]) -> List[int]:
params_numel_arr = np.array(params_numel)
std = np.std(params_numel_arr)
mean = np.mean(params_numel_arr)
upper_limit = mean + 3 * std
return list(filter(lambda x: x <= upper_limit, params_numel))

View File

@ -3,7 +3,7 @@ import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from colossalai.gemini import Chunk, ChunkManager
from colossalai.gemini.chunk import Chunk, ChunkManager
from .placement_policy import PlacementPolicyFactory
@ -56,37 +56,44 @@ class GeminiManager:
self._evict_time = 0
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type)
self._layout_time += time() - start
vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list,
vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx)
self._d2h_volume += vol
self._evict_time += evict_time
# move COMPUTE tensors to CUDA
self._h2d_volume += cuda_demand
@functools.lru_cache(maxsize=None)
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str):
start = time()
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == 'cpu' or chunk.is_empty:
cuda_demand += chunk.mem
if chunk.device_type == 'cuda':
if chunk.is_gathered:
pass
else:
cuda_demand += chunk.chunk_mem - chunk.shard_mem
elif chunk.device_type == 'cpu':
cuda_demand += chunk.chunk_mem
else:
raise RuntimeError
self._comp_cuda_demand_time += time() - start
can_evict_chunks = []
for chunk in self._chunk_manager.chunk_groups[group_name]:
if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device:
can_evict_chunks.append(chunk)
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type)
return cuda_demand, can_evict_chunks
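CUDA demand is now charged per shard: a gathered chunk costs nothing extra, a CUDA shard only needs its missing remainder, and a CPU shard needs the whole chunk. A hedged restatement as a plain function (stand-in arguments, illustrative only):
def chunk_cuda_demand(device_type: str, is_gathered: bool, chunk_mem: int, shard_mem: int) -> int:
    if device_type == 'cuda':
        return 0 if is_gathered else chunk_mem - shard_mem
    if device_type == 'cpu':
        return chunk_mem
    raise RuntimeError(f'unexpected device type: {device_type}')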
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:

View File

@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini import ChunkManager
from colossalai.gemini.chunk import ChunkManager
import torch
import time

View File

@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import Type
import functools
from colossalai.gemini import Chunk, ChunkManager
from colossalai.gemini.chunk import Chunk, ChunkManager
class PlacementPolicy(ABC):
@ -19,7 +19,7 @@ class PlacementPolicy(ABC):
self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
@abstractmethod
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> None:
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
@staticmethod
@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy):
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
volume = 0
start = time()
for chunk in can_evict_chunks:
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
volume += chunk.mem
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.shard_mem
return volume, time() - start
@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy):
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
return 0, 0
@staticmethod
@ -59,7 +59,8 @@ class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
_warmup_non_model_data_ratio: float = 0.8
_steady_cuda_cap_ratio: float = 0.9
@ -70,14 +71,14 @@ class AutoPlacementPolicy(PlacementPolicy):
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: List[Tuple[Chunk, ...]] = [],
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> int:
**kwargs) -> Tuple[int, float]:
"""
Evict tensors from CUDA device.
Args:
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted.
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
@ -114,12 +115,12 @@ class AutoPlacementPolicy(PlacementPolicy):
for chunk in to_free_chunks:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
freed_cuda_model_data += chunk.mem
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_cuda_model_data += chunk.shard_mem
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
return freed_cuda_model_data, time() - start
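The policy frees shards in the recorded order until the CUDA budget is met, and only the shard portion counts as freed memory. A minimal sketch of that loop (stand-in callables, illustrative only):
def evict_until(to_free: int, candidate_chunks, move_chunk_to_cpu) -> int:
    freed = 0
    for chunk in candidate_chunks:
        if freed >= to_free:
            break
        move_chunk_to_cpu(chunk)             # i.e. chunk_manager.move_chunk(chunk, torch.device('cpu'))
        freed += chunk.shard_mem             # only the local shard leaves CUDA memory
    if freed < to_free:
        raise RuntimeError(f'Adjust layout failed! Need {to_free}, freed {freed}')
    return freed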
@staticmethod
@ -147,7 +148,7 @@ class AutoPlacementPolicy(PlacementPolicy):
class PlacementPolicyFactory:
policies: Dict[str, PlacementPolicy] = {
policies: Dict[str, Type[PlacementPolicy]] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy

View File

@ -1,131 +0,0 @@
import queue
import heapq
from abc import ABC, abstractmethod
from typing import Optional, List, Dict
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
def evict_check(st: StatefulTensor) -> bool:
if st.state is not TensorState.COMPUTE and st.device.type == 'cuda':
return True
return False
# Here ST means Stateful Tensor
class BaseSTContainer(ABC):
"""A type of container that store all potential stateful tensors which can be evicted from
CUDA. This kind of stateful tensor should satisfy two conditions. One is that it hasn't been
evicted, meaning the type of its device is CUDA, the other is that it isn't pinned in CUDA
memory, meaning its state isn't COMPUTE.
This container should get a stateful tensor when it become HOLD_LIKE from COMPUTE.
And it pops stateful tensors in function, `evict_tensors`.
In order to acquire an optimal eviction policy, users may need to offer computation step
index of each stateful tensor. So we can use a heap to maintain all potential evictable
statefule tensors. When poping, we can get the stateful tensor that used furthest in
current computation step.
"""
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
self.compute_step_dict = compute_step_dict
self.total_step = total_step
@abstractmethod
def empty(self) -> bool:
pass
@abstractmethod
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
pass
@abstractmethod
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
pass
@abstractmethod
def pop(self) -> Optional[StatefulTensor]:
pass
class QueueSTContainer(BaseSTContainer):
"""Queue type stateful tensor container. This is used in 'cpu' tensor placement policy.
It pops potentially evictable stateful tensors in FIFO order.
"""
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
super().__init__(compute_step_dict, total_step)
self.container = None
def empty(self) -> bool:
assert self.container is not None
return self.container.empty()
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
self.container = queue.SimpleQueue()
for stateful_tensor in stateful_tensor_list:
self.container.put(stateful_tensor)
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
self.container.put(stateful_tensor)
def pop(self) -> Optional[StatefulTensor]:
ret = None
while not self.empty():
out_tensor = self.container.get()
if evict_check(out_tensor):
ret = out_tensor
break
return ret
class HeapSTContainer(BaseSTContainer):
"""Heap type stateful tensor container. This is used in 'auto' tensor placement policy.
It pops potentially evictable stateful tensors ordered by the distance between the
current step and their next used step.
"""
def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int):
super().__init__(compute_step_dict, total_step)
self.container = None
def empty(self) -> bool:
assert self.container is not None
return self.container == []
def create(self, stateful_tensor_list: List[StatefulTensor]) -> None:
self.container = []
for stateful_tensor in stateful_tensor_list:
# we want to pop the tensor which has the greatest next_step
# so the weight is next_step multiplied by -1
weight = -self.__get_next_compute_step(stateful_tensor, -1)
self.container.append((weight, stateful_tensor))
heapq.heapify(self.container)
def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None:
# we want to pop the tensor which has the greatest next_step
# so the weight is next_step multiplied by -1
weight = -self.__get_next_compute_step(stateful_tensor, cur_step)
heapq.heappush(self.container, (weight, stateful_tensor))
def pop(self) -> Optional[StatefulTensor]:
ret = None
while not self.empty():
_, out_tensor = heapq.heappop(self.container)
if evict_check(out_tensor):
ret = out_tensor
break
return ret
def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int):
# compute the id of next step
# if the tensor is not used in the future
# next_step is set to the maximum
next_step = self.total_step
step_list = self.compute_step_dict[stateful_tensor]
for step in step_list:
if step > cur_step:
next_step = step
break
return next_step
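Because heapq is a min-heap, storing (-next_step, tensor) makes the tensor whose next use is furthest away pop first, which is the Belady-style order the comments in this removed file describe. A quick illustration:
import heapq

heap = []
heapq.heappush(heap, (-3, 'used_at_step_3'))
heapq.heappush(heap, (-10, 'used_at_step_10'))
heapq.heappush(heap, (-5, 'used_at_step_5'))
assert heapq.heappop(heap)[1] == 'used_at_step_10'   # the furthest next use is evicted first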

View File

@ -1,3 +0,0 @@
from .chunkv2 import ChunkV2
from .chunk_mgrv2 import ChunkManagerV2
from .search_utils import clasify_params, search_chunk_configuration

View File

@ -3,16 +3,18 @@ import itertools
import torch.distributed as dist
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.gemini.chunk import TensorState, Chunk
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List, Optional, Set
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
@ -208,29 +210,41 @@ class ZeroDDP(ColoDDP):
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False) -> None:
super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = ZeROHookV2(gemini_manager)
self.fp32_params: List[ColoParameter] = []
self.fp32_params: List[ColoTensor] = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
self.chunk_manager.create_group('fp32_param')
# TODO: get param order and filter unused params
for p in module.parameters():
assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False):
p.data = p.half()
continue
fp32_p = p.float().detach()
dp_world_size = p.process_group.dp_world_size()
fp32_data = p.float().data
p.data = p.half()
self.chunk_manager.append_tensor(p, 'fp16_param')
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
self._cast_buffers()
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):
@ -248,10 +262,7 @@ class ZeroDDP(ColoDDP):
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
p.grad = None
else:
p.grad = p.data
p.grad = None
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
@ -276,21 +287,22 @@ class ZeroDDP(ColoDDP):
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
chunk = self.chunk_manager.get_chunk(p)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
self.chunk_manager.release_chunk(chunk)
if reduced and not chunk.is_empty:
if reduced:
if chunk.is_gathered:
chunk.chunk_total.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
self.overflow_counter += chunk.has_inf_or_nan
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
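Putting the hunk together: the gradient is copied into its fp16 chunk slice, the chunk is reduced once every tensor in it is ready, and the sum becomes a mean on whichever buffer is live before the chunk moves to its grads_device. A hedged sketch of that sequence (the function name is illustrative):
def on_grad_ready(chunk_manager, chunk, param, grad):
    chunk_manager.trans_tensor_state(param, TensorState.READY_FOR_REDUCE)
    chunk.copy_tensor_to_chunk_slice(param, grad)
    if chunk_manager.reduce_chunk(chunk):              # True once the whole chunk has been reduced
        target = chunk.chunk_total if chunk.is_gathered else chunk.cuda_shard
        target.div_(chunk.pg_size)                     # turn the sum into a mean across ranks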
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
@ -311,14 +323,11 @@ class ZeroDDP(ColoDDP):
['bias', 'weight']
"""
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
record_flag = (not only_rank_0) or is_rank_0
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
@@ -326,7 +335,7 @@ class ZeroDDP(ColoDDP):
destination = hook_result
return destination
def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -339,30 +348,30 @@ class ZeroDDP(ColoDDP):
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
# save parameters
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
for chunk in chunk_list:
# record the original device of the chunk
org_chunk_dev_typ = chunk.device_type
self.chunk_manager.access_chunk(chunk)
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor in chunk.get_tensors():
rec_p = torch.empty([0])
for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
rec_p = tensor.cpu() # move the whole tensor to CPU mem
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
assert tensor not in param_to_save_data
param_to_save_data[tensor] = rec_p
# release the actual memory of the chunk
self.chunk_manager.release_chunk(chunk)
if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
param_to_save_data[tensor] = record_tensor
del temp_chunk
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None:
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
rec_p = param_to_save_data[fp32_p]
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
record_parameter = param_to_save_data[fp32_p]
destination[prefix + name] = record_parameter
# save all buffers
for name, buf in self.named_buffers():
@@ -466,40 +475,61 @@ class ZeroDDP(ColoDDP):
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
def load(name, dest_tensor, copy_func):
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
def load(param_name, dest_tensor, copy_func):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if input_param.shape != dest_tensor.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(key, input_param.shape,
'the shape in current model is {}.'.format(state_key, input_param.shape,
dest_tensor.shape))
return
try:
with torch.no_grad():
# self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param)
copy_func(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(),
ex.args))
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
input_param.size(), ex.args))
elif strict:
missing_keys.append(key)
missing_keys.append(state_key)
def load_fp32_p(fp32_p, data):
if fp32_p.storage().size() > 0:
self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data)
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
fp32_to_name = dict()
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None:
load(name, fp32_p, partial(load_fp32_p, fp32_p))
self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.chunk_total.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
del temp_chunk
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
for name, buf in persistent_buffers.items():
if buf is not None:

@@ -0,0 +1,20 @@
import torch
import torch.distributed as dist
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device
def get_temp_total_chunk_on_cuda(chunk: Chunk):
if chunk.is_gathered:
return chunk.chunk_total
if chunk.cuda_shard is not None:
shard_temp = chunk.cuda_shard
else:
shard_temp = chunk.cpu_shard.to(get_current_device())
total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device())
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
return total_temp
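The helper works because torch.chunk returns views into total_temp, so the all_gather writes every rank's shard straight into the flat buffer. Below is a standalone sketch of the same pattern, runnable on CPU with a single-process gloo group; it is purely illustrative and not part of this commit.
import torch
import torch.distributed as dist


def gather_shard_to_flat_buffer(shard: torch.Tensor, chunk_size: int, pg) -> torch.Tensor:
    # torch.chunk returns views of `total`, so all_gather fills the flat buffer in place
    total = torch.zeros(chunk_size, dtype=shard.dtype)
    gather_list = list(torch.chunk(total, chunks=dist.get_world_size(pg), dim=0))
    dist.all_gather(gather_list, shard, group=pg)
    return total


if __name__ == '__main__':
    dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29512', rank=0, world_size=1)
    shard = torch.arange(8, dtype=torch.float32)
    assert torch.equal(gather_shard_to_flat_buffer(shard, chunk_size=8, pg=dist.group.WORLD), shard)
    dist.destroy_process_group()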

@@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook):
@contextmanager
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
old_training_phase = self._training_phase
try:
old_training_phase = self._training_phase
self._training_phase = training_phase
yield
finally:

@@ -2,17 +2,14 @@ import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from torch.nn import Parameter
from colossalai.nn.parallel.data_parallel import ZeroDDP
from typing import Dict
from typing import Dict, Tuple, Set
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
from collections import defaultdict, abc as container_abcs
from copy import deepcopy
from itertools import chain
from torch._six import inf
from colossalai.gemini.chunk import Chunk, ChunkManager
class OptimState(Enum):
@@ -33,8 +30,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
Args:
optim (Optimizer): An Optimizer instance.
module (ZeroDDP): A ``ZeroDDP`` instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
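As an aside, a minimal way to wire ZeroDDP and ZeroOptimizer together, assembled from the tests added later in this commit; the Linear module, the 'auto' policy and the hyperparameter values are placeholder choices, and colossalai.launch(...) is assumed to have been called already.
import torch
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(1024, 1024)    # placeholder module

config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager('auto', chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, gpu_margin_mem_ratio=0.0, initial_scale=32)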
@@ -61,11 +58,19 @@ class ZeroOptimizer(ColossalaiOptimizer):
assert isinstance(module, ZeroDDP)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager = self.gemini_manager.chunk_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
self.optim_state = OptimState.UNSCALED
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
for p, fp32_p in zip(module.parameters(), module.fp32_params):
self.fp16_param_to_fp32_param[p] = fp32_p
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
self.chunk16_set.add(chunk_16)
self.__init__optimizer()
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
@@ -75,7 +80,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
@@ -90,16 +95,26 @@ class ZeroOptimizer(ColossalaiOptimizer):
self._register_states = disposable(self._register_states_)
def _update_params_ptr(self):
for group in self.optim.param_groups:
for p in group['params']:
if not self.module.chunk_manager.get_chunk(p).is_empty:
p.data = self.fp16_param_to_fp32_param[p]
else:
assert p.grad is None
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group['params']:
chunk32 = self.param_to_chunk32[fake_param]
begin, end = self.param_to_range[fake_param]
chunk16 = chunk32.paired_chunk
fake_param.data = chunk16.payload[begin:end]
fake_param.grad = fake_param.data
fake_param.data = chunk32.payload[begin:end]
def _update_fp16_params(self):
self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param')
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
assert fake_param.grad is None
fake_param.data = none_tensor
for chunk16 in self.chunk16_set:
chunk16.optim_update()
def _check_overflow(self):
# clear previous overflow record
@@ -128,6 +143,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
@@ -138,45 +154,14 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.zero_grad()
self._update_fp16_params()
return
self._update_params_ptr()
ret = self.optim.step(*args, **kwargs)
self._register_states()
self.zero_grad()
self._update_fp16_params()
return ret
def compute_grad_norm(self, norm_type: float = 2.0) -> float:
norm_type = float(norm_type)
if not self.chunk_manager.enable_distributed_storage:
return compute_grad_norm(self.module.parameters(), norm_type)
non_distributed_params = []
distributed_params = []
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
non_distributed_params.append(p)
else:
distributed_params.append(p)
non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
device=get_current_device())
if norm_type == inf:
dist.all_reduce(distributed_norm_tensor,
op=dist.ReduceOp.MAX,
group=self.chunk_manager.process_group.dp_process_group())
total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
else:
dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
total_norm = non_distributed_norm + distributed_norm_tensor.item()
total_norm = total_norm**(1 / norm_type)
return total_norm
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
total_norm = self.compute_grad_norm(norm_type)
_clip_grad_norm(self.module.parameters(), max_norm, total_norm)
return total_norm
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
@@ -197,24 +182,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
fp32_params_used_cuda_margin_mem = 0
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
self.chunk_manager.chunk_groups['fp32_param']):
if fp32_param_chunk.is_empty:
continue
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
# stores grad now
self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device())
self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device())
fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem
for p in fp16_param_chunk.get_tensors():
state = self.optim.state[p]
for group in self.param_groups:
for fake_param in group['params']:
chunk32 = self.param_to_chunk32[fake_param]
chunk16 = chunk32.paired_chunk
if chunk32.device_type == 'cuda':
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
self.chunk_manager.move_chunk(chunk32, get_current_device())
# stores grad now
self.chunk_manager.move_chunk(chunk16, get_current_device())
self.module.set_chunk_grad_device(chunk16, get_current_device())
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
for group in self.param_groups:
for fake_param in group['params']:
chunk32 = self.param_to_chunk32[fake_param]
if chunk32.device_type == 'cuda':
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(get_current_device())
self.module._setup_grads_ptr()
def _register_states_(self):
for group in self.optim.param_groups:
for p in group['params']:
@@ -223,110 +215,27 @@ class ZeroOptimizer(ColossalaiOptimizer):
if isinstance(val, torch.Tensor):
self.chunk_manager.add_extern_static_tensor(val)
def state_dict(self, only_rank_0: bool = True):
    r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
    This saves memory usage.

    It contains two entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.
    * param_groups - a list containing all parameter groups where each
        parameter group is a dict
    """
    is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
    if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
        return
    optim_state_dict = super().state_dict()
    scaler_state_dict = self.grad_scaler.state_dict()
    optim_state_dict['scaler'] = scaler_state_dict
    if not self.chunk_manager.enable_distributed_storage:
        return optim_state_dict
    local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
    if not self.chunk_manager.process_group.has_cpu_groups:
        self.chunk_manager.process_group.set_cpu_groups()
    output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
    if only_rank_0:
        dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
        dist.gather_object(local_state,
                           output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
                           dst=dst_rank,
                           group=self.chunk_manager.process_group.cpu_dp_process_group())
        if not is_rank_0:
            return
    else:
        dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
    for state in output:
        optim_state_dict['state'].update(state)
    return optim_state_dict

def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    if 'scaler' not in state_dict:
        self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
    else:
        self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))

    # Validate the state_dict
    groups = self.param_groups
    saved_groups = deepcopy(state_dict['param_groups'])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Update the state
    id_map = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
                                                           )), chain.from_iterable((g['params'] for g in groups)))
    }

    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = self.fp16_param_to_fp32_param[id_map[k]]
            if param.storage().size() > 0:
                state[param] = cast(param, deepcopy(v))
        else:
            state[k] = deepcopy(v)

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({'state': state, 'param_groups': param_groups})

def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}

def __init__optimizer(self):

    def get_range_pair(local_chunk: Chunk, local_param: Parameter):
        param_info = local_chunk.tensors_info[local_param]
        begin = max(0, param_info.offset - local_chunk.shard_begin)
        end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
        return begin, end

    for group in self.optim.param_groups:
        fake_params_list = list()

        for param in group['params']:
            chunk16 = self.chunk_manager.get_chunk(param)
            range_pair = get_range_pair(chunk16, param)
            if range_pair[0] >= range_pair[1]:
                continue

            fake_param = torch.nn.Parameter(torch.empty([0]))
            self.param_to_chunk32[fake_param] = chunk16.paired_chunk
            self.param_to_range[fake_param] = range_pair

            fake_params_list.append(fake_param)

        group['params'] = fake_params_list
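get_range_pair clips a parameter's [offset, end) interval within its chunk to the slice owned by the local shard; parameters that fall entirely outside the shard produce an empty range and are skipped, so no fake parameter is created for them. A small numeric illustration with hypothetical values:
# this rank's shard covers [1024, 2048) of the chunk, i.e. shard_size = 1024
shard_begin, shard_size = 1024, 1024
# a parameter occupying [900, 1300) of the chunk overlaps the shard in [1024, 1300)
param_offset, param_end = 900, 1300
begin = max(0, param_offset - shard_begin)      # -> 0, start of the local slice
end = min(shard_size, param_end - shard_begin)  # -> 276, end of the local slice
assert (begin, end) == (0, 276)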

@@ -6,11 +6,11 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
from typing import Callable, Type
import torch.distributed as dist
import os
import random
@@ -32,10 +32,9 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
return ColoDDP(module, process_group=pg)
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, pg)
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
chunk_config = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)
@@ -51,7 +50,7 @@ class Net(torch.nn.Module):
return self.fc2(self.fc1(x))
def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]):
with ColoInitContext(device=get_current_device()):
model = Net().cuda()
w1 = model.fc1.weight
@@ -62,8 +61,14 @@ def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], Col
logits = model(x)
loss = torch.sum(logits)
model.backward(loss)
if ddp_cls is ZeroDDP:
w1s_grad = w1
else:
w1s_grad = w1.grad
w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())]
dist.all_gather(w1_grads, w1.grad)
dist.all_gather(w1_grads, w1s_grad)
assert torch.equal(w1_grads[0], w1_grads[1])
w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())]
dist.all_gather(w2_grads, w2.grad)
@@ -74,8 +79,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
set_seed(dist.get_rank())
run_fwd_bwd(ColoDDP, init_ddp)
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))
run_fwd_bwd(ZeroDDP, init_ddpv2)
@pytest.mark.dist

@@ -8,14 +8,11 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP, ColoDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDP
from collections import OrderedDict
from colossalai.tensor import ProcessGroup, ColoParameter
from colossalai.testing import parameterize
def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@@ -30,42 +27,11 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True):
for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
assert na == nb
if not allow_empty:
assert pa.storage().size() > 0
assert pb.storage().size() > 0
else:
if pa.storage().size() == 0 or pb.storage().size() == 0:
continue
if same_dtype:
assert pa.dtype == pb.dtype
temp_pb = pb
else:
temp_pb = pb.to(pa.dtype)
assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb)
def init_ddp(module: torch.nn.Module) -> ColoDDP:
pg = ProcessGroup()
return ColoDDP(module, process_group=pg)
def init_ddpv2(module: torch.nn.Module,
use_chunk: bool = False,
use_zero: bool = False,
placement_policy: str = 'cuda') -> ZeroDDP:
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
return ZeroDDP(module, gemini_manager)
def run_ddp_state_dict():
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -88,44 +54,9 @@ def run_ddp_state_dict():
check_state_dict_equal(torch_state_dict, state_dict)
@parameterize('use_chunk', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('use_zero', [False, True])
@parameterize('only_rank_0', [False, True])
def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0):
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
torch_model = model_builder().cuda()
org_torch_model = copy.deepcopy(torch_model)
torch_state_dict = torch_model.state_dict()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model = init_ddpv2(model, use_chunk, use_zero, placement_policy)
for param in model.parameters():
if isinstance(param, ColoParameter):
assert param.get_process_group() is not None
model.load_state_dict(torch_state_dict, strict=False)
check_model_equal(model, torch_model, allow_empty=True, same_dtype=False)
for param in model.parameters():
if isinstance(param, ColoParameter):
assert param.get_process_group() is not None
pg = ProcessGroup()
state_dict = model.state_dict(only_rank_0=only_rank_0)
if not only_rank_0 or pg.dp_local_rank() == 0:
torch_model.load_state_dict(state_dict, strict=False)
check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_ddp_state_dict()
run_zero_state_dict()
@pytest.mark.dist

@@ -1,74 +0,0 @@
import pytest
import torch
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer
@pytest.mark.dist
def test_stateful_tensor_container():
st1 = StatefulTensor(torch.randn(1, device='cuda'))
st2 = StatefulTensor(torch.randn(2, device='cuda'))
st3 = StatefulTensor(torch.randn(3, device='cuda'))
stateful_tensor_list = [st1, st2, st3]
step_list = [st1, st2, st3, st3, st2, st1]
compute_step_dict = dict()
compute_step_dict[st1] = [0, 5]
compute_step_dict[st2] = [1, 4]
compute_step_dict[st3] = [2, 3]
def run_queue_test():
# test queue container
queue_container = QueueSTContainer(compute_step_dict, 6)
queue_container.create(stateful_tensor_list)
res_list = []
for i in range(6):
stateful_tensor = step_list[i]
stateful_tensor.trans_state(TensorState.COMPUTE)
st_out = queue_container.pop()
st_out.move_to(torch.device('cpu'))
res_list.append(st_out.payload.size(0))
stateful_tensor.move_to(torch.device('cuda'))
queue_container.push(stateful_tensor, i)
stateful_tensor.trans_state(TensorState.HOLD)
assert res_list == [2, 3, 1, 2, 3, 2]
run_queue_test()
def run_heap_test():
# test heap container
st1.move_to(torch.device('cuda'))
st2.move_to(torch.device('cuda'))
st3.move_to(torch.device('cuda'))
heap_container = HeapSTContainer(compute_step_dict, 6)
heap_container.create(stateful_tensor_list)
res_list = []
for i in range(6):
stateful_tensor = step_list[i]
stateful_tensor.trans_state(TensorState.COMPUTE)
st_out = heap_container.pop()
if st_out is not None:
res_list.append(st_out.payload.size(0))
st_out.move_to(torch.device('cpu'))
stateful_tensor.move_to(torch.device('cuda'))
heap_container.push(stateful_tensor, i)
stateful_tensor.trans_state(TensorState.HOLD)
assert res_list == [3, 1, 2, 3, 2]
run_heap_test()
if __name__ == '__main__':
test_stateful_tensor_container()

@@ -3,7 +3,7 @@ import colossalai
import pytest
import torch.multiprocessing as mp
from functools import partial
from colossalai.gemini.update import ChunkManagerV2
from colossalai.gemini.chunk import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
@@ -19,23 +19,17 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
def exam_chunk_memory(keep_gathered, pin_memory):
pg = ProcessGroup()
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(
keep_gathered, pin_memory))
debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
config = {
2: dict(
chunk_size=128,
keep_gathered=keep_gathered
)
}
config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory)
chunk_manager = ChunkManager(config)
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
for p in params:
chunk_manager.append_tensor(p, 'param', 2)
chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
chunk_manager.close_all_groups()
assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
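For reference, the config passed to the new ChunkManager is a dict keyed by data-parallel world size, each value holding the chunk settings for that group; a hand-written equivalent of what this test builds (keep_gathered=False chosen as one of the parameterized values):
from colossalai.gemini.chunk import ChunkManager

example_config = {2: dict(chunk_size=128, keep_gathered=False)}
chunk_manager = ChunkManager(example_config)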

@@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.gemini import TensorState
from colossalai.gemini.update import ChunkV2
from colossalai.gemini.chunk import Chunk
def dist_sum(x):
@@ -38,14 +38,12 @@ def check_euqal(param, param_cp):
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup()
my_chunk = ChunkV2(
chunk_size=1024,
process_group=pg,
dtype=torch.float32,
init_device=init_device,
keep_gathered=keep_gathered,
pin_memory=pin_memory
)
my_chunk = Chunk(chunk_size=1024,
process_group=pg,
dtype=torch.float32,
init_device=init_device,
keep_gathered=keep_gathered,
pin_memory=pin_memory)
param_list = []
param_cp_list = []

@@ -0,0 +1,109 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
for (p0, p1) in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
pg = ProcessGroup()
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
model.eval()
torch_model.eval()
set_seed(pg.dp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
model.backward(loss)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
check_grad(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)

@@ -0,0 +1,118 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
optimizer.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
model.eval()
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(1)

@@ -8,7 +8,7 @@ import torch.distributed as dist
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.gemini.update import search_chunk_configuration
from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup
@@ -35,12 +35,11 @@ def exam_search_chunk_size():
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
config_dict = search_chunk_configuration(
model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True)
config_dict = search_chunk_configuration(model,
search_range_mb=1,
search_interval_byte=16,
min_chunk_size_mb=0,
filter_exlarge_params=True)
for key in config_dict:
chunk_size = config_dict[key]['chunk_size']

@@ -0,0 +1,111 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False)
for key, value in torch_dict.items():
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
exam_load_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_ddp(1)

@@ -0,0 +1,97 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_zero_optim_state_dict(placement_policy, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters())
optim = ZeroOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32
set_seed(dist.get_rank() * 3 + 128)
model.train()
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 0:
break
optim.zero_grad()
logits = model(input_ids, attn_mask)
logits = logits.float()
loss = criterion(logits, input_ids)
optim.backward(loss)
optim.step()
optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict)
new_state = optim.state_dict()['state']
org_state = optim_state_dict['state']
for k, v in org_state.items():
w = new_state[k]
for n, m in v.items():
if isinstance(m, torch.Tensor):
o = w[n]
if m.device != o.device:
o = o.to(m.device)
assert torch.equal(m, o)
else:
assert m == w[n]
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_zero_optim_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_optim(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_optim(1)

@@ -1,86 +0,0 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from typing import List
from functools import partial
from colossalai.gemini import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port
from colossalai.tensor import ProcessGroup as ColoProcessGroup
def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
for p, has_tensor in zip(params, has_tensors):
if has_tensor:
assert p.storage().size() > 0
assert p.device.type == 'cuda'
else:
assert p.storage().size() == 0
# HAS_TENSORS[use_chunk][use_zero]
HAS_TENSORS = {
True: {
True: [[True, True, False], [False, False, True]],
False: [[True, True, True], [True, True, True]]
},
False: {
True: [[True, False, True], [False, True, False]],
False: [[True, True, True], [True, True, True]]
}
}
TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero):
pg = ColoProcessGroup()
rank = pg.rank()
if rank == 0:
print(f'use_chunk={use_chunk}, use_zero={use_zero}')
params = [torch.rand(8, 8) for _ in range(3)]
chunk_size = 128 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
chunk_manager.create_group('param')
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
for p in params:
chunk_manager.append_tensor(p, 'param')
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
chunks = chunk_manager.get_chunks(params)
for chunk in chunks:
chunk_manager.access_chunk(chunk)
check_has_params(params, [True, True, True])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
for chunk in chunks:
chunk_manager.release_chunk(chunk)
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
for chunk in chunks:
chunk_manager.move_chunk(chunk, torch.device('cpu'))
assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
assert chunk_manager.total_mem['cuda'] == 0
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chunk_zero()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_mapping(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_mapping(2)

@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
def check_param_equal(model, torch_model, pg: ProcessGroup):
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
if p.storage().size() > 0:
assert p.dtype == torch.float16
assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
zero_dict = model.state_dict(only_rank_0=False)
torch_dict = torch_model.state_dict()
def check_grad_equal(model, torch_model, pg: ProcessGroup):
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
if p.grad is not None:
assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
pg.tp_local_rank(), pg.tp_world_size()), \
f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
if key == 'model.lm_head.weight':
continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
"parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
p.set_tensor_spec(*spec)
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
def run_gpt(placement_policy, tp_init_spec_func=None):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
if tp_init_spec_func:
tp_init_spec_func(model, pg)
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy))
dp_world_size = pg.dp_world_size()
config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[dp_world_size]['chunk_size'] = 5000
config_dict[dp_world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
init_device = torch.device('cpu')
else:
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=1)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
print(chunk_manager)
check_param_equal(model, torch_model, pg)
check_param(model, torch_model, pg)
model.eval()
torch_model.eval()
@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
if i > 2:
break
input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert tensor_equal(logits, torch_logits)
check_grad_equal(model, torch_model, pg)
optim.step()
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
zero_optim.step()
torch_optim.step()
check_param_equal(model, torch_model, pg)
check_param(model, torch_model, pg)
def run_dist(rank, world_size, port):

@@ -1,100 +0,0 @@
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ProcessGroup
def check_state(s1, s2):
for v1, v2 in zip(s1.values(), s2.values()):
if isinstance(v1, torch.Tensor):
v1 = v1.to(v2.device)
assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
else:
assert v1 == v2
def check_load_state_dict(optim, torch_optim):
for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
for p, torch_p in zip(group['params'], torch_group['params']):
state = optim.optim.state[p]
torch_state = torch_optim.state[torch_p]
if p.storage().size() == 0:
assert len(state) == 0
check_state(state, torch_state)
def check_state_dict(state_dict, torch_state_dict):
for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
assert k1 == k2
check_state(s1, s2)
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('only_rank_0', [False, True])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
model = model_builder()
model = model.cuda()
torch_model = model_builder().cuda()
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=1)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
for p in torch_model.parameters():
p.grad = torch.rand_like(p)
torch_optim.step()
torch_state_dict = torch_optim.state_dict()
optim.load_state_dict(torch_state_dict)
check_load_state_dict(optim, torch_optim)
state_dict = optim.state_dict(only_rank_0)
if not only_rank_0 or pg.rank() == 0:
check_state_dict(state_dict, torch_state_dict)
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_zero_optim_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_optim_state_dict(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_optim_state_dict(2)