[NFC]fix typo colossalai/auto_parallel nn utils etc. (#3779)

* fix typo colossalai/autochunk auto_parallel amp

* fix typo colossalai/auto_parallel nn utils etc.
pull/3808/head
digger yu 2023-05-23 15:28:20 +08:00 committed by GitHub
parent e871e342b3
commit 9265f2d4d7
16 changed files with 46 additions and 46 deletions

View File

@@ -6,7 +6,7 @@ from tqdm import tqdm
 from .utils import is_rank_0
-# Dahaos/rm-static
+# Dahoas/rm-static
 class RmStaticDataset(Dataset):
     """
     Dataset for reward model

View File

@@ -155,7 +155,7 @@ class EmbeddingModuleHandler(ModuleHandler):
         Convert the sharding spec from the logical shape to the physical shape.
         """
         # create multiple sharding strategies for the inputs
-        # as input can be multi-dimensinal and the partition dim is only 2D,
+        # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
         strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
                                                                                        input_name=str(
@@ -221,7 +221,7 @@ class EmbeddingFunctionHandler(NodeHandler):
         Convert the sharding spec from the logical shape to the physical shape.
         """
         # create multiple sharding strategies for the inputs
-        # as input can be multi-dimensinal and the partition dim is only 2D,
+        # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
         strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
                                                                                        input_name=str(

View File

@@ -23,7 +23,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
                                                            weight_name: str) -> ShardingStrategy:
     """
     This function is a helper function used by both module node handler and function node handler. This function will
-    convert the sharding spec for the transposed weight to the correct partititon spec.
+    convert the sharding spec for the transposed weight to the correct partition spec.
     Args:
         strategy (ShardingStrategy): the strategy generated by the strategy generator.
@@ -197,7 +197,7 @@ class LinearModuleHandler(MetaInfoModuleHandler):
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
         # create multiple sharding strategies for the inputs
-        # as input can be multi-dimensinal and the partition dim is only 2D,
+        # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input
         strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
                                                                                     input_name=str(self.node.args[0]),
@@ -267,7 +267,7 @@ class LinearFunctionHandler(MetaInfoNodeHandler):
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))
         # create multiple sharding strategies for the inputs
-        # as input can be multi-dimensinal and the partition dim is only 2D,
+        # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input
         strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
                                                                                     input_name=str(self.node.args[0]),
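Background on the helper touched above: `torch.nn.Linear` stores its weight as `(out_features, in_features)`, i.e. transposed relative to the logical `(in, out)` view used when generating strategies, so partition dims recorded on the logical weight must be swapped before they apply to the physical tensor. The snippet below is only an illustrative sketch of that swap under that assumption, not the ColossalAI helper itself.

```python
# Illustrative sketch: swap partition dims recorded for the logical
# (in_features, out_features) weight so they apply to the physical
# nn.Linear weight stored as (out_features, in_features).
from typing import Dict, List


def swap_weight_partition(dim_partition: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Swap dims 0 and 1 of a 2D dim -> mesh-axes partition mapping."""
    return {1 - dim: mesh_axes for dim, mesh_axes in dim_partition.items()}


# logical weight sharded along in_features (dim 0) on mesh axis 0
print(swap_weight_partition({0: [0]}))    # {1: [0]} -> shard dim 1 of the stored weight
```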

View File

@@ -48,8 +48,8 @@ def get_matmul_type(input_dim: int, other_dim: int):
     Determine which type of matmul operation should be executed for the given tensor dimensions.
     Args:
-        input_dim (int): the number of dimensions for the input tenosr
-        other_dim (int): the number of dimensions for the other tenosr
+        input_dim (int): the number of dimensions for the input tensor
+        other_dim (int): the number of dimensions for the other tensor
     """
     if input_dim == 1 and other_dim == 1:
         matmul_type = MatMulType.DOT
@@ -268,13 +268,13 @@ class Viewer(BmmTransform):
             dim_partition_dict = sharding_spec.dim_partition_dict
             entire_shape = sharding_spec.entire_shape
-            # upddate the dimension index for the matrix dimensions
+            # update the dimension index for the matrix dimensions
             if 2 in dim_partition_dict:
                 dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
             if 1 in dim_partition_dict:
                 dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)
-            # map the logical batch dim to phyiscal batch dim
+            # map the logical batch dim to physical batch dim
             if 0 in dim_partition_dict:
                 batch_dim_shard = dim_partition_dict.pop(0)
                 dim_partition_dict[physical_batch_dim] = batch_dim_shard
@@ -414,7 +414,7 @@ class MatMulHandler(MetaInfoNodeHandler):
     def _get_logical_shape_for_mm(self):
         """
-        We need to handle the input tensor for a matrix-matrix multiplcation as the input
+        We need to handle the input tensor for a matrix-matrix multiplication as the input
         tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
         (e.g. [4] -> [1, 4]).
         """

View File

@@ -212,7 +212,7 @@ class NodeHandler(ABC):
         return self.strategies_vector
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
-        # tranform the strategy generated
+        # transform the strategy generated
         # e.g. to process the sharding strategy for the transposed weights
         return strategy

View File

@@ -30,7 +30,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
     """
     if isinstance(input_, Node):
-        assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data'
+        assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data'
         meta_tensor = input_._meta_data
         assert meta_tensor is not None, "The given node's _meta_data attribute is None"
         shape = meta_tensor.shape

View File

@@ -6,12 +6,12 @@ import torch
 class PreviousStatus(Enum):
     """
-    This class shows the status of previous comparision.
+    This class shows the status of previous comparison.
     """
     RESET = 0
-    # ORIGIN means the dimension size of original tensor is larger in the previous comparision.
+    # ORIGIN means the dimension size of original tensor is larger in the previous comparison.
     ORIGIN = 1
-    # TGT means the dimension size of target tensor is larger in the previous comparision.
+    # TGT means the dimension size of target tensor is larger in the previous comparison.
     TGT = 2
@@ -91,7 +91,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
             tgt_index += 1
             if previous_label == PreviousStatus.TGT:
-                # if the target dimension size is larger in the previous comparision, which means
+                # if the target dimension size is larger in the previous comparison, which means
                 # the origin dimension size has already accumulated larger than target dimension size, so
                 # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
                 reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
@@ -111,7 +111,7 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
             origin_index += 1
             if previous_label == PreviousStatus.ORIGIN:
-                # if the origin element is larger in the previous comparision, which means
+                # if the origin element is larger in the previous comparison, which means
                 # the target element has already accumulated larger than origin element, so
                 # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
                 reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
@@ -139,7 +139,7 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
     Rule:
         For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
         the function will return false.
-        To illustrate this issue, there are two cases to analyse:
+        To illustrate this issue, there are two cases to analyze:
         1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
         operation without distributed tensor.
         2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
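The rule quoted above means a reshape can keep its sharding only when every sharded input dimension is the smallest index of the dimension group it is merged into, so the shard boundary survives the merge. Below is a small illustrative check written independently of the ColossalAI helpers; the function name and the dictionary layouts (origin-dim groups mapped to target-dim groups, sharded dims mapped to mesh axes) are assumptions for the example only.

```python
from typing import Dict, List, Tuple

# reshape_mapping: groups of origin dims -> groups of target dims; e.g. for
# (4, 8, 16) -> (32, 16) a mapping could be {(0, 1): (0,), (2,): (1,)}.
# sharded_dims: origin dims that carry a shard, e.g. {0: [0]} (dim 0 on mesh axis 0).


def keeps_sharding(reshape_mapping: Dict[Tuple[int, ...], Tuple[int, ...]],
                   sharded_dims: Dict[int, List[int]]) -> bool:
    """A sharded origin dim survives only if it is the minimum of its merged group."""
    for origin_group in reshape_mapping:
        for dim in sharded_dims:
            if dim in origin_group and dim != min(origin_group):
                return False
    return True


print(keeps_sharding({(0, 1): (0,), (2,): (1,)}, {0: [0]}))   # True: dim 0 is min of (0, 1)
print(keeps_sharding({(0, 1): (0,), (2,): (1,)}, {1: [0]}))   # False: dim 1 is not the min
```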

View File

@@ -13,7 +13,7 @@ from .nvme_optimizer import NVMeOptimizer
 class CPUAdam(NVMeOptimizer):
     """Implements Adam algorithm.
-    Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
+    Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
     But the parameters and gradients should on the same device:
       * Parameters on CPU and gradients on CPU is allowed.
       * Parameters on GPU and gradients on GPU is allowed.

View File

@@ -13,19 +13,19 @@ from .nvme_optimizer import NVMeOptimizer
 class HybridAdam(NVMeOptimizer):
     """Implements Adam algorithm.
-    Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
+    Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
     But the parameters and gradients should on the same device:
       * Parameters on CPU and gradients on CPU is allowed.
       * Parameters on GPU and gradients on GPU is allowed.
       * Parameters on GPU and gradients on CPU is **not** allowed.
-    `HybriadAdam` requires CUDA extensions which can be built during installation or runtime.
+    `HybridAdam` requires CUDA extensions which can be built during installation or runtime.
     This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
     * For parameters updating on CPU, it uses CPUAdam.
     * For parameters updating on GPU, it uses FusedAdam.
-    * Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients.
+    * Hybrid precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients.
     :class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
     or ``torch.optim.Adam`` with ``adamw_mode=False``
@@ -131,7 +131,7 @@ class HybridAdam(NVMeOptimizer):
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
-                # record the state by gruop and update at once
+                # record the state by group and update at once
                 g_l.append(p.grad.data)
                 p_l.append(p.data)
                 m_l.append(state['exp_avg'])
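The docstring above states that `HybridAdam` is a drop-in replacement for `torch.optim.AdamW` (or `Adam` with `adamw_mode=False`), so swapping it into an ordinary training loop looks roughly like the sketch below; the hyperparameter values are placeholders rather than recommendations, and a CUDA build of the extensions is assumed.

```python
# Minimal sketch: use HybridAdam where torch.optim.AdamW would normally go.
# Assumes ColossalAI's CUDA extensions are available; values are placeholders.
import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(64, 64).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-3, weight_decay=1e-2)

x = torch.randn(8, 64, device='cuda')
loss = model(x).sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()   # CPU-resident params go through CPUAdam, GPU-resident through FusedAdam
```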

View File

@@ -20,8 +20,8 @@ def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
         return
     torch.cuda.current_stream().wait_stream(stream)
     # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
-    # PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is
-    # freed, its memory is likely to be reused by newly constructed tenosrs. By default,
+    # PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
+    # freed, its memory is likely to be reused by newly constructed tensors. By default,
     # this allocator traces whether a tensor is still in use by only the CUDA stream where it
     # was created. When a tensor is used by additional CUDA streams, we need to call record_stream
     # to tell the allocator about all these streams. Otherwise, the allocator might free the
@@ -294,7 +294,7 @@ class CachedParamMgr(torch.nn.Module):
             print(
                 f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
             )
-            print(f'cpu_to_cuda_elpase {elapsed} sec')
+            print(f'cpu_to_cuda_elapse {elapsed} sec')
             for k, v in self._elapsed_dict.items():
                 print(f'{k}: {v}')
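The comment in the first hunk describes the standard PyTorch pattern for tensors produced on a side stream: make the consumer stream wait on the producer stream, then call `Tensor.record_stream` so the caching allocator knows the tensor is also used on the consumer stream. A minimal standalone sketch of that pattern (independent of `CachedParamMgr`; it only needs a CUDA device):

```python
# Minimal sketch of the wait_stream + record_stream pattern described above.
# Requires a CUDA device; tensor sizes are arbitrary placeholders.
import torch

copy_stream = torch.cuda.Stream()
src = torch.randn(1024, device='cpu', pin_memory=True)

with torch.cuda.stream(copy_stream):
    # asynchronous H2D copy issued on the side stream
    dst = src.to('cuda', non_blocking=True)

# make the default stream wait until the copy has finished
torch.cuda.current_stream().wait_stream(copy_stream)
# tell the caching allocator that `dst` is also used on the current stream,
# so its memory is not reused while this stream may still read it
dst.record_stream(torch.cuda.current_stream())

out = dst * 2   # now safe to consume on the current stream
```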

View File

@@ -324,7 +324,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     norm_type = float(norm_type)
     # Parameters can be on CPU or CUDA
-    # If parameters are on CPU, disable CUDA kernerls
+    # If parameters are on CPU, disable CUDA kernels
     # Calculate norm.
     if norm_type == inf:
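For reference, gradient-norm clipping reduces all gradients to a single norm (the max of absolute values for the inf norm, a p-norm otherwise) and rescales them when that norm exceeds `max_norm`. The ColossalAI `clip_grad_norm_fp32` additionally aggregates norms across parallel groups; the sketch below is only the plain single-device algorithm.

```python
# Simplified single-device sketch of gradient-norm clipping; the real
# clip_grad_norm_fp32 also handles tensor-parallel gradient aggregation.
from math import inf
import torch


def clip_grad_norm_simple(parameters, max_norm: float, norm_type: float = 2.0) -> float:
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    if norm_type == inf:
        total_norm = max(g.detach().abs().max().item() for g in grads)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type).item()
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.detach().mul_(clip_coef)   # scale gradients in place
    return total_norm
```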

View File

@@ -46,7 +46,7 @@ detector.detect()
 I have made some comments on the right of the output for your understanding.
-Note that the total `Mem` of all the tensors and parameters is not equal to `Total GPU Memery Allocated`. PyTorch's memory management is really complicated, and for models of a large scale, it's impossible to figure out clearly.
+Note that the total `Mem` of all the tensors and parameters is not equal to `Total GPU Memory Allocated`. PyTorch's memory management is really complicated, and for models of a large scale, it's impossible to figure out clearly.
 **The order of print is not equal to the order the tensor creates, but they are really close.**
@@ -61,7 +61,7 @@ Note that the total `Mem` of all the tensors and parameters is not equal to `Tot
 + mlp.2.bias cuda:0 (32,) True torch.float32 128 B
 ------------------------------------------------------------------------------------------------------------
 Detect Location: "test_tensor_detector.py" line 27
-Totle GPU Memery Allocated on cuda:0 is 4.5 KB
+Total GPU Memory Allocated on cuda:0 is 4.5 KB
 ------------------------------------------------------------------------------------------------------------
@@ -72,7 +72,7 @@ Totle GPU Memery Allocated on cuda:0 is 4.5 KB
 + Tensor cuda:0 (32,) True torch.float32 128 B # output
 ------------------------------------------------------------------------------------------------------------
 Detect Location: "test_tensor_detector.py" line 30
-Totle GPU Memery Allocated on cuda:0 is 5.5 KB
+Total GPU Memory Allocated on cuda:0 is 5.5 KB
 ------------------------------------------------------------------------------------------------------------
@@ -82,7 +82,7 @@ Totle GPU Memery Allocated on cuda:0 is 5.5 KB
 + Tensor cuda:0 () True torch.float32 4 B # loss
 ------------------------------------------------------------------------------------------------------------
 Detect Location: "test_tensor_detector.py" line 32
-Totle GPU Memery Allocated on cuda:0 is 6.0 KB
+Total GPU Memory Allocated on cuda:0 is 6.0 KB
 ------------------------------------------------------------------------------------------------------------
@@ -103,7 +103,7 @@ Totle GPU Memery Allocated on cuda:0 is 6.0 KB
 - Tensor cuda:0 (8,) True torch.float32 32 B # deleted activation
 ------------------------------------------------------------------------------------------------------------
 Detect Location: "test_tensor_detector.py" line 34
-Totle GPU Memery Allocated on cuda:0 is 10.0 KB
+Total GPU Memory Allocated on cuda:0 is 10.0 KB
 ------------------------------------------------------------------------------------------------------------
@@ -117,7 +117,7 @@ Totle GPU Memery Allocated on cuda:0 is 10.0 KB
 + Tensor cuda:0 (32,) False torch.float32 128 B
 ------------------------------------------------------------------------------------------------------------
 Detect Location: "test_tensor_detector.py" line 36
-Totle GPU Memery Allocated on cuda:0 is 14.0 KB
+Total GPU Memory Allocated on cuda:0 is 14.0 KB
 ------------------------------------------------------------------------------------------------------------
 ```
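For context, output like the above comes from calling `detector.detect()` around a small training step. A minimal usage sketch follows; the import path and constructor arguments (`include_cpu`, `module`) are assumptions based on the class code in this commit, so check the `TensorDetector` source for the exact signature.

```python
# Minimal usage sketch; constructor arguments are assumptions inferred from
# the TensorDetector code touched in this commit, not a verified signature.
import torch
from colossalai.utils import TensorDetector

model = torch.nn.Sequential(torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 32)).cuda()
detector = TensorDetector(include_cpu=False, module=model)

detector.detect()                                   # snapshot: parameters already on cuda:0
out = model(torch.randn(1, 8, device='cuda'))
detector.detect()                                   # newly created output/activation tensors show up as '+'
```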

View File

@@ -55,7 +55,7 @@ class TensorDetector():
         return self.mem_format(memory_size)
     def mem_format(self, real_memory_size):
-        # format the tensor memory into a reasonal magnitude
+        # format the tensor memory into a reasonable magnitude
        if real_memory_size >= 2**30:
             return str(real_memory_size / (2**30)) + ' GB'
         if real_memory_size >= 2**20:
@@ -71,7 +71,7 @@ class TensorDetector():
             if (not self.include_cpu) and obj.device == torch.device('cpu'):
                 continue
             self.detected.append(id(obj))
-            # skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch
+            # skip parameters we had added in __init__ when module is an instance of nn.Module for the first epoch
             if id(obj) not in self.tensor_info:
                 name = type(obj).__name__
@@ -84,7 +84,7 @@ class TensorDetector():
                     name = par_name + ' (with grad)'
                 else:
                     # with no grad attached
-                    # there will be no new paramters created during running
+                    # there will be no new parameters created during running
                     # so it must be in saved_tensor_info
                     continue
                 # we can also marked common tensors as tensor(with grad)
@@ -155,7 +155,7 @@ class TensorDetector():
             if device == torch.device('cpu'):
                 continue
             gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device))
-            self.info += f"Totle GPU Memery Allocated on {device} is {gpu_mem_alloc}\n"
+            self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n"
             self.info += LINE
             self.info += '\n\n'
             if self.show_info:

View File

@@ -102,7 +102,7 @@ class ChunkManager:
         """
         if chunk in self.accessed_chunks:
             return
-        self.__sub_memroy_usage(chunk.memory_usage)
+        self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == 'cpu':
             chunk.shard_move(get_current_device())
         self.__add_accessed_chunk(chunk)
@@ -114,7 +114,7 @@ class ChunkManager:
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
-            self.__sub_memroy_usage(chunk.memory_usage)
+            self.__sub_memory_usage(chunk.memory_usage)
             self.__sub_accessed_chunk(chunk)
             self.__add_memory_usage(chunk.memory_usage)
@@ -123,7 +123,7 @@ class ChunkManager:
         """
         if not chunk.can_move or chunk.device_type == device.type:
             return
-        self.__sub_memroy_usage(chunk.memory_usage)
+        self.__sub_memory_usage(chunk.memory_usage)
         chunk.shard_move(device, force_copy)
         self.__add_memory_usage(chunk.memory_usage)
@@ -138,7 +138,7 @@ class ChunkManager:
         """
         if not chunk.can_reduce:
             return False
-        self.__sub_memroy_usage(chunk.memory_usage)
+        self.__sub_memory_usage(chunk.memory_usage)
         chunk.reduce()
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
@@ -228,11 +228,11 @@ class ChunkManager:
         return self.chunk_groups[group_name]
     def __close_one_chunk(self, chunk: Chunk):
-        self.__sub_memroy_usage(chunk.memory_usage)
+        self.__sub_memory_usage(chunk.memory_usage)
         chunk.close_chunk()
         self.__add_memory_usage(chunk.memory_usage)
-    def __sub_memroy_usage(self, usage: Dict[str, int]):
+    def __sub_memory_usage(self, usage: Dict[str, int]):
         for k, v in usage.items():
             self.total_mem[k] -= v
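Every `ChunkManager` operation in the hunks above follows the same accounting pattern: subtract the chunk's current memory usage from the running totals, mutate the chunk (move, reduce, close), then add its new usage back, so `total_mem` stays correct even when the mutation changes where or how much memory the chunk occupies. A stripped-down illustration of that pattern follows; the class and method names are illustrative, not the real ChunkManager API.

```python
# Stripped-down illustration of the sub-usage / mutate / add-usage pattern;
# the names here are illustrative, not ColossalAI's actual API.
from typing import Callable, Dict


class UsageTracker:
    def __init__(self) -> None:
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def _sub_usage(self, usage: Dict[str, int]) -> None:
        for k, v in usage.items():
            self.total_mem[k] -= v

    def _add_usage(self, usage: Dict[str, int]) -> None:
        for k, v in usage.items():
            self.total_mem[k] += v

    def apply(self, chunk, mutate: Callable[[], None]) -> None:
        # bracket the mutation so totals reflect the chunk's new placement/size
        self._sub_usage(chunk.memory_usage)
        mutate()
        self._add_usage(chunk.memory_usage)


if __name__ == '__main__':
    class _FakeChunk:
        def __init__(self) -> None:
            self.memory_usage = {'cpu': 0, 'cuda': 1024}

    tracker, chunk = UsageTracker(), _FakeChunk()
    tracker._add_usage(chunk.memory_usage)
    # "move" the chunk to cpu while keeping the totals consistent
    tracker.apply(chunk, lambda: chunk.memory_usage.update({'cpu': 1024, 'cuda': 0}))
    print(tracker.total_mem)   # {'cpu': 1024, 'cuda': 0}
```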

View File

@@ -85,7 +85,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
     Classify the parameters by their dp degree
     Args:
-        param_order (OrderedParamGenerator): the order of param be visied
+        param_order (OrderedParamGenerator): the order of param be vised
         strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False.
     Returns:

View File

@@ -59,7 +59,7 @@ class MemStats(object):
         time step.
         Args:
-            param_list (List[torch.nn.Parameter]): a list of torch paramters.
+            param_list (List[torch.nn.Parameter]): a list of torch parameters.
         """
         for p in param_list:
             if p not in self._param_step_dict: