[doc] added documentation to chunk and chunk manager (#1094)

* [doc] added documentation to chunk and chunk manager

* polish code

* polish code

* polish code
Frank Lee 2022-06-10 15:33:06 +08:00 committed by GitHub
parent 1f894e033f
commit cb18922c47
3 changed files with 215 additions and 12 deletions


@@ -120,7 +120,7 @@ class ColoDDPV2(ColoDDP):
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
-            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
+            if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data
@@ -154,7 +154,7 @@ class ColoDDPV2(ColoDDP):
             chunk = self.chunk_manager.get_chunk(p)
             reduced = self.chunk_manager.reduce_chunk(chunk)
             self.chunk_manager.release_chunk(chunk)
-            if reduced and not chunk.is_free:
+            if reduced and not chunk.is_empty:
                 self.overflow_counter += chunk.has_inf_or_nan
                 self.chunk_manager.move_chunk(chunk, self.grads_device[p])
         return empty_grad


@@ -38,6 +38,16 @@ class ChunkFullError(Exception):
 class Chunk:
+    """
+    A chunk is a contiguous memory space which contains multiple tensors.
+
+    Args:
+        chunk_size (int): the number of elements in a chunk
+        src_rank (int): the rank of the process that owns the chunk
+        dtype (torch.dtype): the data type of the chunk
+        init_device (torch.device): optional, the device where the chunk is initialized.
+            Defaults to None, which means the current GPU.
+    """

     def __init__(self,
                  chunk_size: int,
                  src_rank: int,
@@ -51,17 +61,34 @@ class Chunk:
         self.dtype = dtype
         self.device = init_device or get_current_device()
         self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
+        # the full chunk is kept only on the process that owns it
         if not self.is_src_rank:
             self.data.storage().resize_(0)
+        # each tensor is associated with a TensorInfo that tracks its meta info
         self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
         self.mem = self.size * self.data.element_size()

     def append(self, tensor: torch.Tensor) -> None:
+        """
+        Add a tensor to the chunk.
+
+        Args:
+            tensor (torch.Tensor): a tensor to be added to the chunk
+        """
         assert tensor.dtype == self.dtype
         new_utilized_size = self.utilized_size + tensor.numel()
+        # raise an exception when the chunk size is exceeded
         if new_utilized_size > self.size:
             raise ChunkFullError
+        # set the initial tensor state
         tensor_state = TensorState.FREE
+        # if this process owns the chunk, copy the tensor into the chunk buffer;
+        # otherwise set its storage size to 0 to reduce memory consumption
         if self.is_src_rank:
             self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
             tensor_state = TensorState.HOLD
@@ -72,6 +99,9 @@ class Chunk:
         self.utilized_size = new_utilized_size
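The two methods above implement the core packing scheme: a chunk is one flat buffer, and an appended tensor is copied in at the current offset and then re-pointed at its slice. A minimal, self-contained sketch of the same idea in plain PyTorch (not the library API; `buffer` and `offset` are illustrative names):

```python
import torch

# A flat 1024-element buffer stands in for Chunk.data.
buffer = torch.empty(1024, dtype=torch.float32)
offset = 0  # stands in for Chunk.utilized_size

t = torch.randn(4, 8)
n = t.numel()
assert offset + n <= buffer.numel()               # the real code raises ChunkFullError here
buffer[offset:offset + n].copy_(t.view(-1))       # copy the payload into the chunk
t.data = buffer[offset:offset + n].view(t.shape)  # re-point the tensor at its slice
offset += n

# On a non-owner process the chunk keeps the tensor object but frees the payload:
# buffer.storage().resize_(0) shrinks the underlying storage to zero bytes.
```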
     def release(self) -> None:
+        """
+        Release the memory space on processes which do not own the chunk.
+        """
         if not self.is_src_rank:
             self.data.storage().resize_(0)
             self._update_tensors_state(TensorState.FREE)
@@ -86,19 +116,38 @@ class Chunk:
             tensor_info.state = next_state

     def access(self) -> None:
+        """
+        Broadcast the chunk to synchronize the tensors across data-parallel processes.
+        """
+        # recover the chunk storage on non-owner processes and
+        # broadcast the chunk from the owner to all processes
         if not self.is_src_rank:
             self.data.storage().resize_(self.size)
         self.data.data = self.data.to(get_current_device())
         dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+        # re-point the tensors at the refreshed chunk buffer
         self._update_tensors_ptr()
         if not self.is_src_rank:
             self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
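For reference, the broadcast pattern in `access` can be sketched as follows, assuming `torch.distributed` has been initialized (e.g. via `torchrun`) and `src` is the owner rank; the real method additionally moves the buffer to the current device and refreshes the tensor pointers:

```python
import torch
import torch.distributed as dist

def access_sketch(buffer: torch.Tensor, size: int, src: int) -> None:
    # non-owners first re-grow the storage they previously freed
    if dist.get_rank() != src:
        buffer.storage().resize_(size)
    # the owner's payload is replicated to every rank in the group
    dist.broadcast(buffer, src=src)
```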
     def move_device(self, device: torch.device) -> None:
+        """
+        Move the chunk to a target device.
+
+        Args:
+            device (torch.device): the target device for data movement
+        """
         self.data.data = self.data.to(device)
         self._update_tensors_ptr()

     def reduce(self, is_all_reduce: bool = False) -> None:
+        """
+        Reduce or all-reduce the chunk.
+
+        Args:
+            is_all_reduce (bool): optional, whether to all-reduce the chunk. Defaults to False.
+        """
         self.data.data = self.data.to(get_current_device())
         if is_all_reduce:
             dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
@@ -108,6 +157,13 @@ class Chunk:
         self._update_tensors_state(TensorState.HOLD)

     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
+        """
+        Transition the tensor into the given next state.
+
+        Args:
+            tensor (torch.Tensor): a torch Tensor object
+            tensor_state (TensorState): the target state for the transition
+        """
         assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
         # As the gradient hook can be triggered either before or after post-backward
         # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
@@ -123,12 +179,22 @@ class Chunk:
         self.tensors_info[tensor].state = tensor_state
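The states referenced in the comments above form a small state machine. A hedged reconstruction of it (the exact enum and transition table live in the library source, which is not shown in this diff):

```python
from enum import Enum, auto

class TensorState(Enum):
    FREE = auto()
    COMPUTE = auto()
    HOLD = auto()
    HOLD_AFTER_BWD = auto()
    READY_FOR_REDUCE = auto()

# Because the gradient hook may fire before or after the post-backward step,
# a gradient tensor may legally take either path:
#   COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE
#   COMPUTE -> READY_FOR_REDUCE
```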
     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
+        """
+        Copy a data slice to the memory space indexed by the input tensor in the chunk.
+
+        Args:
+            tensor (torch.Tensor): the tensor used to retrieve meta information
+            data_slice (torch.Tensor): the tensor to be copied to the chunk
+        """
         tensor_info = self.tensors_info[tensor]
         self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
         tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)

     @property
     def can_release(self) -> bool:
+        """
+        Check whether the chunk can be released.
+        """
         for tensor_info in self.tensors_info.values():
             if tensor_info.state != TensorState.HOLD:
                 return False
@@ -136,6 +202,9 @@ class Chunk:

     @property
     def can_move_device(self) -> bool:
+        """
+        Check whether the chunk can be moved across devices.
+        """
         for tensor_info in self.tensors_info.values():
             if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
                 return False
@@ -143,26 +212,38 @@ class Chunk:

     @property
     def can_reduce(self) -> bool:
+        """
+        Check whether the chunk can be reduced.
+        """
         for tensor_info in self.tensors_info.values():
             if tensor_info.state != TensorState.READY_FOR_REDUCE:
                 return False
         return True
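The three `can_*` properties are all-quantified checks over the tensor states; written equivalently with `all()` (using the `TensorState` sketch above):

```python
def can_release(states):
    # every tensor must be parked in HOLD before the storage may be freed
    return all(s == TensorState.HOLD for s in states)

def can_move_device(states):
    # a chunk is pinned while any tensor is being computed or awaiting reduction
    return all(s not in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE) for s in states)

def can_reduce(states):
    # reduction needs every gradient slice to be final
    return all(s == TensorState.READY_FOR_REDUCE for s in states)
```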
     @property
-    def is_free(self) -> bool:
+    def is_empty(self) -> bool:
+        """
+        Check whether the chunk is empty.
+        """
         return self.data.storage().size() == 0

     def __repr__(self) -> str:
-        return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
+        return f'Chunk: src rank={self.src_rank}, size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, empty={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

     @property
     def has_inf_or_nan(self) -> bool:
+        """
+        Check whether the chunk has inf or nan values.
+        """
         return torch.isinf(self.data[:self.utilized_size]).any().item() or \
             torch.isnan(self.data[:self.utilized_size]).any().item()
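Note that `has_inf_or_nan` only scans the utilized prefix of the buffer, so stale values in the unused tail of a partially filled chunk cannot trigger a false overflow. A quick self-contained check of that behaviour:

```python
import torch

buf = torch.tensor([1.0, float('nan'), 7.0, 0.0])

utilized = 2
overflow = torch.isinf(buf[:utilized]).any().item() or torch.isnan(buf[:utilized]).any().item()
print(overflow)  # True: the NaN lies inside the utilized region

utilized = 1
overflow = torch.isinf(buf[:utilized]).any().item() or torch.isnan(buf[:utilized]).any().item()
print(overflow)  # False: only the first element is scanned
```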
     def copy_(self, dest_chunk: 'Chunk'):
+        """
+        Copy the data of the given chunk into this chunk.
+        """
-        assert not self.is_free
-        assert not dest_chunk.is_free
+        assert not self.is_empty
+        assert not dest_chunk.is_empty
         assert self.size == dest_chunk.size
         assert self.utilized_size == dest_chunk.utilized_size
         self.data.copy_(dest_chunk.data)
@@ -170,6 +251,9 @@ class Chunk:

     @property
     def device_type(self) -> str:
+        """
+        Get the device type of the chunk.
+        """
         return self.data.device.type

     def __hash__(self) -> int:
@@ -183,6 +267,14 @@ class Chunk:
 class ChunkManager:
+    """
+    A manager class that manipulates the tensors in chunks.
+
+    Args:
+        chunk_size (int): the size of a chunk
+        enable_distributed_storage (bool): optional, whether to store chunks in a distributed manner. Defaults to False.
+        init_device (torch.device): optional, the device on which chunks are initialized. Defaults to None.
+    """

     def __init__(self,
                  chunk_size: Optional[int],
@@ -201,54 +293,89 @@ class ChunkManager:
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

     def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
+        """
+        Append a tensor to a chunk.
+
+        Args:
+            tensor (torch.Tensor): a tensor to append to the chunk
+            group_name (str): the name of the chunk group
+        """
         assert tensor not in self.tensor_chunk_map
         if self.chunk_size is not None and tensor.numel() > self.chunk_size:
             raise ValueError(
                 f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
         if group_name not in self.chunk_groups:
             self.chunk_groups[group_name] = deque()
         try:
+            # append the tensor to the last chunk
             self.chunk_groups[group_name][-1].append(tensor)
         except (IndexError, ChunkFullError):
+            # this branch is triggered when the group has no chunk yet or the
+            # last chunk in the group is full; in that case a new chunk is
+            # created and assigned to its owner process
            chunk_size = self.chunk_size or tensor.numel()
             src_rank = self._get_next_src_rank(group_name)
             chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
             if self.enable_distributed_storage and self.chunk_size is None:
                 self.rank_load[group_name][src_rank] += chunk_size
             self.chunk_groups[group_name].append(chunk)
             chunk.append(tensor)
-            if not chunk.is_free:
+            if not chunk.is_empty:
                 self.total_mem[chunk.device_type] += chunk.mem
         self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1]
         if not self.enable_distributed_storage:
+            # as distributed storage is not enabled, there is no need to
+            # broadcast chunks, so we mark them as accessed right away
             self.accessed_chunks.add(self.chunk_groups[group_name][-1])
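The try/except flow above is a greedy packing loop: append to the last chunk, and open a new one when the group is empty (`IndexError`) or the chunk overflows (`ChunkFullError`). A self-contained toy version with stand-in classes:

```python
from collections import deque

class ChunkFullError(Exception):
    pass

class ToyChunk:
    def __init__(self, size):
        self.size, self.used = size, 0
    def append(self, numel):
        if self.used + numel > self.size:
            raise ChunkFullError
        self.used += numel

group = deque()
for numel in [300, 300, 300, 500]:    # tensor sizes, chunk size 1024
    try:
        group[-1].append(numel)       # IndexError if the group is still empty
    except (IndexError, ChunkFullError):
        group.append(ToyChunk(1024))  # open a fresh chunk, then retry
        group[-1].append(numel)
print([c.used for c in group])        # [900, 500]
```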
     def _get_next_src_rank(self, group_name: str) -> int:
         if not self.enable_distributed_storage:
+            # every chunk is owned by the current rank when distributed storage is disabled
             return gpc.get_local_rank(ParallelMode.DATA)
         if self.chunk_size is None:
             if group_name not in self.rank_load:
                 self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
+            # the new chunk is owned by the process holding the smallest number of elements so far
             src_rank = torch.argmin(self.rank_load[group_name]).item()
         else:
+            # chunks are owned by processes in a round-robin fashion
             chunk_idx = len(self.chunk_groups[group_name])
             src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
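The two ownership policies condense to a few lines (a sketch; `world_size`, `rank_load` and `n_chunks` mirror the quantities used above):

```python
import torch

def next_src_rank(world_size, rank_load=None, n_chunks=0):
    if rank_load is not None:
        # variable-sized chunks: assign to the least-loaded rank
        return torch.argmin(rank_load).item()
    # fixed-size chunks: round-robin over the data-parallel ranks
    return n_chunks % world_size

print(next_src_rank(4, rank_load=torch.tensor([10, 3, 7, 9])))  # 1
print(next_src_rank(4, n_chunks=5))                             # 1
```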
     def access_chunk(self, chunk: Chunk) -> None:
+        """
+        Synchronize the chunk via broadcast.
+
+        Args:
+            chunk (Chunk): the chunk to synchronize
+        """
         if chunk in self.accessed_chunks:
             if chunk.device_type != 'cuda':
                 self.total_mem[chunk.device_type] -= chunk.mem
                 chunk.move_device(get_current_device())
                 self.total_mem[chunk.device_type] += chunk.mem
             return
-        if not chunk.is_free:
+        if not chunk.is_empty:
+            # as the chunk is moved to the current device,
+            # the memory consumption of its original device is reduced
             self.total_mem[chunk.device_type] -= chunk.mem
         chunk.access()
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
     def release_chunk(self, chunk: Chunk) -> None:
+        """
+        Release the memory space of a chunk.
+
+        Args:
+            chunk (Chunk): the chunk whose memory space is to be released
+        """
         if not self.enable_distributed_storage:
             return
         if chunk not in self.accessed_chunks:
@@ -256,22 +383,44 @@ class ChunkManager:
         if chunk.can_release:
             chunk.release()
             self.accessed_chunks.remove(chunk)
-            if chunk.is_free:
+            if chunk.is_empty:
+                # update the memory consumption after releasing
                 self.total_mem[chunk.device_type] -= chunk.mem
     def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
+        """
+        Move the chunk to the target device.
+
+        Args:
+            chunk (Chunk): the chunk to move
+            device (torch.device): the target device
+        """
         if chunk.data.device == device:
             return
-        if chunk.can_move_device and not chunk.is_free:
+        if chunk.can_move_device and not chunk.is_empty:
             self.total_mem[chunk.device_type] -= chunk.mem
             chunk.move_device(device)
             self.total_mem[chunk.device_type] += chunk.mem
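A recurring pattern in this class is the per-device memory ledger: `total_mem` is decremented on the source device type and incremented on the target whenever a non-empty chunk moves. Schematically:

```python
total_mem = {'cpu': 0, 'cuda': 0}

def account_move(chunk_mem, src_type, dst_type):
    # the chunk's bytes leave one device type and land on the other
    total_mem[src_type] -= chunk_mem
    total_mem[dst_type] += chunk_mem

total_mem['cpu'] = 4096
account_move(4096, 'cpu', 'cuda')
print(total_mem)  # {'cpu': 0, 'cuda': 4096}
```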
     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
+        """
+        Transition the tensor state according to the pre-defined state machine.
+
+        Args:
+            tensor (torch.Tensor): the tensor whose state is to be transitioned
+            state (TensorState): the next tensor state
+        """
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)

     def reduce_chunk(self, chunk: Chunk) -> bool:
+        """
+        Reduce or all-reduce the chunk. If enable_distributed_storage is True, the chunk
+        is reduced to its owner process; otherwise it is all-reduced.
+
+        Args:
+            chunk (Chunk): the chunk to reduce
+        """
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -280,16 +429,39 @@ class ChunkManager:
         return True

     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
+        """
+        Copy data to the chunk.
+
+        Args:
+            tensor (torch.Tensor): the tensor used to retrieve meta information
+            data (torch.Tensor): the tensor to be copied to the chunk
+        """
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)

     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
+        """
+        Return the chunk owning the tensor.
+
+        Args:
+            tensor (torch.Tensor): a torch tensor object
+        """
         return self.tensor_chunk_map[tensor]

     def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
+        """
+        Add tensors to the buffer for lazy release.
+
+        Args:
+            tensors (List[torch.Tensor]): the tensors to be released lazily
+        """
         self.lazy_release_tensors.extend(tensors)

     def exec_lazy_release(self) -> None:
+        """
+        Execute the release of the tensors in the lazy release buffer.
+        """
         for chunk in self.get_chunks(self.lazy_release_tensors):
             self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
@@ -305,6 +477,13 @@ class ChunkManager:

     @staticmethod
     def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
+        """
+        Calculate the utilization rate of a chunk.
+
+        Args:
+            chunk_size (int): the size of a chunk
+            params_numel (List[int]): the numbers of elements of the parameters
+        """
         assert len(params_numel) > 0
         total_size = 0
         total_utilized_size = 0
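The utilization metric begun above is total parameter elements over total allocated chunk elements under greedy packing. A hedged reconstruction of the full computation (the remainder of the method is elided from this diff):

```python
from typing import List

def chunk_util_sketch(chunk_size: int, params_numel: List[int]) -> float:
    n_chunks, used = 0, 0
    for numel in params_numel:
        if n_chunks == 0 or used + numel > chunk_size:
            n_chunks += 1  # open a new chunk, greedy packing
            used = 0
        used += numel
    return sum(params_numel) / (n_chunks * chunk_size)

print(chunk_util_sketch(1024, [300, 300, 300, 500]))  # 1400 / 2048 ≈ 0.68
```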
@@ -323,6 +502,17 @@ class ChunkManager:
                           search_range: int,
                           n_grids: int,
                           min_chunk_size: Optional[int] = None) -> int:
+        """
+        Search for the chunk size that yields optimal chunk utilization.
+
+        Args:
+            module (torch.nn.Module): a torch module object
+            search_range (int): the range of chunk sizes to search. The actual search range is from
+                max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range.
+            n_grids (int): the number of intervals in the search range
+            min_chunk_size (int): optional, the minimum chunk size. Defaults to None.
+        """
         assert search_range % n_grids == 0
         # TODO(ver217): sort params and filter unused ones
         params_numel = [p.numel() for p in module.parameters()]
@@ -342,11 +532,24 @@ class ChunkManager:
         return best_chunk_size
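Putting the docstring into runnable form: the search evaluates `n_grids` evenly spaced candidate sizes starting at max(min_chunk_size, max_param_size) and keeps the one with the best utilization. A sketch built on `chunk_util_sketch` from the previous example (the real method derives `params_numel` from the module's parameters):

```python
def search_chunk_size_sketch(params_numel, search_range, n_grids, min_chunk_size=None):
    assert search_range % n_grids == 0
    step = search_range // n_grids
    start = max(min_chunk_size or 0, max(params_numel))
    best_size, best_util = start, chunk_util_sketch(start, params_numel)
    for size in range(start + step, start + search_range + 1, step):
        util = chunk_util_sketch(size, params_numel)
        if util > best_util:
            best_size, best_util = size, util
    return best_size

print(search_chunk_size_sketch([300, 300, 300, 500], search_range=1024, n_grids=8))  # 1524
```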
     def copy_chunk_group(self, dest_group_name: str, src_group_name: str):
+        """
+        Copy chunk data from one group to another group.
+
+        Args:
+            dest_group_name (str): the destination group, which receives the copied data
+            src_group_name (str): the source group, which provides the data to copy
+        """
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
-            if not dest_chunk.is_free:
+            if not dest_chunk.is_empty:
                 dest_chunk.copy_(src_chunk)

     def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        """
+        Get all chunks owning the input tensors.
+
+        Args:
+            tensors (Iterable[torch.Tensor]): the tensors used to look up chunks
+        """
         chunks = []
         for tensor in tensors:
             chunk = self.get_chunk(tensor)

@@ -64,7 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
-                if not self.module.chunk_manager.get_chunk(p).is_free:
+                if not self.module.chunk_manager.get_chunk(p).is_empty:
                     p.data = self.fp16_param_to_fp32_param[p]
                 else:
                     assert p.grad is None