mirror of https://github.com/InternLM/InternLM
add pipeline memory balance
parent 42851be36b
commit 0b1c6c6704
@@ -81,6 +81,10 @@ hybrid_zero_optimizer = dict(
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
# cuda memory balance for activation
cuda_memory_balance=False,
cuda_memory_balance_amount=1 * 1024, # MB
cuda_memory_balance_compensation={},
)

loss = dict(
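For reference, a minimal sketch of how these options might be enabled in a training config; the compensation values are illustrative (not taken from this commit) and all other hybrid_zero_optimizer fields are omitted:

    hybrid_zero_optimizer = dict(
        reduce_bucket_size=512 * 1024 * 1024,
        clip_grad_norm=1.0,
        # opt in to balancing activation memory across pipeline stages
        cuda_memory_balance=True,
        cuda_memory_balance_amount=1 * 1024,  # rough per-stage activation footprint, in MB
        # optional per-pipeline-rank correction in MB; negative keys count back from the last stage
        cuda_memory_balance_compensation={1: 512, -1: -256},
    )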
@@ -140,8 +144,8 @@ pipeline parallel (dict):
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=8,
pipeline=dict(size=1, interleaved_overlap=True),
zero1=-1,
pipeline=dict(size=2, interleaved_overlap=True),
sequence_parallel=False,
)
@@ -4,11 +4,13 @@
import math
from functools import partial
from itertools import product
from typing import List

import torch
import torch.distributed as dist
from torch.optim import Optimizer

from internlm.core.communication import recv_obj_meta, send_obj_meta
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.monitor import send_alert_message
@@ -33,12 +35,34 @@ from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer

from .utils import compute_norm
from .utils import compute_norm, find_subset_with_target_sum

inf = math.inf
logger = get_logger(__file__)


def _find_tensors_with_target_memory(tensors: List[torch.Tensor], target: int) -> List[int]:
tensor_mems = [tensor.nelement() * tensor.element_size() for tensor in tensors]
approximate_thresholds = [0.01 * i for i in range(1, 100)]
result = None

for approximate_threshold in approximate_thresholds:
result = find_subset_with_target_sum(tensor_mems, target * 1024 * 1024, approximate_threshold)
if result is not None:
break

return result
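A quick usage sketch (tensor sizes are illustrative): the target is given in MB, and the loop relaxes the tolerance from 1% to 99% until find_subset_with_target_sum can return a set of tensor indices.

    # four fp16 tensors of 64 MB each; ask for roughly 128 MB worth of them
    params = [torch.empty(32 * 2**20, dtype=torch.float16) for _ in range(4)]
    idxs = _find_tensors_with_target_memory(params, target=128)
    # the first two tensors already sum to exactly 128 MB, so idxs == [0, 1]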
def _flatten_and_sync_params(tensors: List[torch.Tensor]) -> torch.Tensor:
with torch.no_grad():
flat_tensor = flatten(tensors)
flat_tensor = flat_tensor.data.cuda()
sync_param(flat_tensor=flat_tensor, tensor_list=tensors)

return flat_tensor


class BaseOptimizer(Optimizer):
"""
Base Optimizer.
@@ -133,6 +157,47 @@ class HybridZeroOptimizer(BaseOptimizer):
# self._overlap_communication = overlap_communication
self._reduce_bucket_size = reduce_bucket_size

# CUDA memory balance
self._enable_memory_balance = zero_cfg.cuda_memory_balance
if self._enable_memory_balance:
assert gpc.get_world_size(ParallelMode.PIPELINE) > 0, "pipeline parallel size must be > 0"
assert gpc.get_world_size(ParallelMode.PIPELINE) % 2 == 0, "pipeline parallel size must be even"

_peer_local_rank = gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
_self_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

self._memory_balance_role = gpc.get_local_rank(ParallelMode.PIPELINE) // (
gpc.get_world_size(ParallelMode.PIPELINE) // 2
) # 0: sender, 1: receiver
self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
self._fp32_flat_proxy_param_of_current_rank = None
self._memory_balance_comm_handle = None

compensation_conf = {
k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
for k, v in zero_cfg.cuda_memory_balance_compensation.items()
}
_compensation_amount = compensation_conf.get(_self_local_rank, 0) - compensation_conf.get(
_peer_local_rank, 0
)
if self._memory_balance_role == 1:
_compensation_amount = -_compensation_amount
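To make the pairing concrete, a small sketch of how role and peer work out for a pipeline-parallel size of 4 (pipeline-local ranks, mirroring the expressions above):

    pp_size = 4
    for local_rank in range(pp_size):
        peer = pp_size - local_rank - 1
        role = local_rank // (pp_size // 2)  # 0: sender (first half), 1: receiver (second half)
        print(f"rank {local_rank} -> peer {peer}, role {role}")
    # rank 0 -> peer 3, role 0
    # rank 1 -> peer 2, role 0
    # rank 2 -> peer 1, role 1
    # rank 3 -> peer 0, role 1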
# We balance the memory load caused by the different activation sizes of different pipeline stages by
# having the latter half of the pipeline stages proxy a portion of the optimizer parameters of the former
# half. Typically, the first stage's activation occupies pp_size units of memory, decreasing by one unit
# per stage, so the last stage's activation occupies 1 unit; the amount to proxy can therefore be derived
# from the pp_rank. Moving S bytes of optimizer state from a stage to its peer narrows their memory gap
# by 2 * S, so the activation gap is first halved. Since 1 set of optimizer parameters corresponds to 2
# sets of optimizer states, the amount is further divided by 3. Finally, the parameters being selected
# are fp16 while the actual optimizer state parameters are fp32, so the amount is divided by 2 once more.
self._memory_balance_amount = (
(zero_cfg.cuda_memory_balance_amount * abs(_peer_local_rank - _self_local_rank) + _compensation_amount)
/ 2
/ 3
/ 2
)
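To make the arithmetic concrete, a worked example assuming a 4-stage pipeline, cuda_memory_balance_amount=1 * 1024 (MB) and no compensation; pipeline rank 0 is paired with rank 3, so the stage gap is 3:

    gap_mb = 1 * 1024 * abs(3 - 0)   # 3072 MB activation imbalance between the paired stages
    target_mb = gap_mb / 2 / 3 / 2   # halve the gap, /3 for param + two optimizer states, /2 for fp16 vs fp32
    assert target_mb == 256.0        # rank 0 selects ~256 MB of fp16 parameters to proxy on rank 3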
# gradient scaler
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
@@ -172,6 +237,9 @@ class HybridZeroOptimizer(BaseOptimizer):
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
# We only select the parameters to be proxied in the first parameter group.
_enable_memory_balance = group_id == 0 and self._enable_memory_balance

group_params = param_group["params"]

# add the fp16 params to fp16_param_groups for bookkeeping
@@ -182,14 +250,28 @@
self.param_group_no_params_ranks.append(no_params_ranks)
self.param_group_has_params.append(self._zero_local_rank not in no_params_ranks)

# split proxy parameters
if _enable_memory_balance and self._memory_balance_role == 0:
proxy_params_per_rank = [
_find_tensors_with_target_memory(params, self._memory_balance_amount) for params in params_per_rank
]
else:
proxy_params_per_rank = [None for _ in params_per_rank]

# store the mapping between param and rank; each param should belong to only one rank
for rank, params in enumerate(params_per_rank):
for rank in range(self._zero_world_size):
params = params_per_rank[rank]
proxy_params = proxy_params_per_rank[rank]

# check whether any rank is not assigned params.
if len(params) != 0:
self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
for param in params:
setattr(param, "group_id", group_id)
self._param_store.set_param_to_rank(param, rank)
if len(params) == 0:
continue

for param in params:
setattr(param, "group_id", group_id)
self._param_store.set_param_to_rank(param, rank)

self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params, proxy_params)

# move to cpu to make room to create the flat tensor
for param in group_params:
@@ -198,13 +280,28 @@
# flatten the reordered tensors
for rank in range(self._zero_world_size):
# No flat fp16 buffer is allocated if the process has no parameters.
if rank not in self.param_group_no_params_ranks[group_id]:
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
with torch.no_grad():
flat_tensor = flatten(tensor_list)
flat_tensor = flat_tensor.data.cuda()
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
if rank in self.param_group_no_params_ranks[group_id]:
continue

_params = self._param_store.get_fp16_params_by_rank_group(rank, group_id, option="without_proxy")
_flat_tensor = _flatten_and_sync_params(_params)
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, _flat_tensor)

if _enable_memory_balance and self._memory_balance_role == 0:
_params = self._param_store.get_fp16_params_by_rank_group(rank, group_id, option="proxy_only")
_flat_tensor = _flatten_and_sync_params(_params)
self._param_store.add_flat_proxy_param_by_rank_group(rank, group_id, _flat_tensor)

if _enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(self._zero_local_rank, group_id)

send_obj_meta(flat_proxy_param, next_rank=self._memory_balance_peer)
dist.send(flat_proxy_param, dst=self._memory_balance_peer)
elif _enable_memory_balance and self._memory_balance_role == 1:
flat_offload_shape = recv_obj_meta(prev_rank=self._memory_balance_peer)

flat_proxy_param = torch.empty(flat_offload_shape, device=get_current_device(), dtype=self._dtype)
dist.recv(flat_proxy_param, src=self._memory_balance_peer)

# create a copy of fp32 weights of the parameters for which this rank is responsible
# No flat fp32 buffer is allocated if the process has no parameters.
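The proxy-parameter transfer above follows a common two-step pattern: the sender first ships the tensor's shape (send_obj_meta) so the receiver can allocate a matching buffer, then ships the payload with dist.send / dist.recv. A generic sketch of the same pattern with plain torch.distributed point-to-point calls (the helper names here are illustrative, not InternLM APIs):

    import torch
    import torch.distributed as dist

    def send_flat_tensor(t: torch.Tensor, dst: int) -> None:
        # the peer cannot allocate a buffer without knowing the size, so send the numel first
        dist.send(torch.tensor([t.numel()], device=t.device), dst=dst)
        dist.send(t, dst=dst)

    def recv_flat_tensor(src: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        n = torch.empty(1, dtype=torch.int64, device=device)
        dist.recv(n, src=src)
        buf = torch.empty(int(n.item()), device=device, dtype=dtype)
        dist.recv(buf, src=src)
        return buf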
@@ -221,7 +318,15 @@
# need to replace the params in the `params` field in the optimizer
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group["params"] = [fp32_flat_current_rank]
optim_params = [fp32_flat_current_rank]

if _enable_memory_balance and self._memory_balance_role == 1:
flat_proxy_param = flat_proxy_param.to(device=device, dtype=fp32_flat_current_rank.dtype)
flat_proxy_param.requires_grad = True
optim_params.append(flat_proxy_param)
self._fp32_flat_proxy_param_of_current_rank = flat_proxy_param

param_group["params"] = optim_params

# set reduction state
for param in self._fp16_param_groups[group_id]:
@@ -438,23 +543,54 @@ class HybridZeroOptimizer(BaseOptimizer):
self._param_store.reset_reduced_data_for_compute_norm()

# accumulate gradient
proxy_gradients = []
avg_gradients = self._grad_store._averaged_gradients

for group_id in range(self.num_param_groups):
# the following operations are performed only on the rank to which parameters are assigned.
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
param_group = self._param_store.get_fp16_params_by_rank_group(self._zero_local_rank, group_id)
if self._zero_local_rank in self.param_group_no_params_ranks[group_id]:
continue

if group_id not in avg_gradients:
avg_gradients[group_id] = []
param_group = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="without_proxy"
)

param_idx = 0
if group_id not in avg_gradients:
avg_gradients[group_id] = []

param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients[group_id]) == param_idx:
avg_gradients[group_id].append(param.grad)
else:
avg_gradients[group_id][param_idx].add_(param.grad)
param_idx += 1

if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
param_group = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id, option="proxy_only"
)
for param in param_group:
if param.grad is not None:
if len(avg_gradients[group_id]) == param_idx:
avg_gradients[group_id].append(param.grad)
else:
avg_gradients[group_id][param_idx].add_(param.grad)
param_idx += 1
assert param.grad is not None, "gradient of proxy parameter is None"
proxy_gradients.append(param.grad)

# send offloaded gradients to the receiver
if self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_grads = flatten(proxy_gradients)

self._memory_balance_comm_handle = dist.isend(flat_proxy_grads, self._memory_balance_peer)
# torch.cuda.synchronize()
elif self._enable_memory_balance and self._memory_balance_role == 1:
_shape = self._fp32_flat_proxy_param_of_current_rank.shape
_device = self._fp32_flat_proxy_param_of_current_rank.device
flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)

self._memory_balance_comm_handle = dist.irecv(flat_proxy_gradient, self._memory_balance_peer)
# torch.cuda.synchronize()
self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
)

# the gradients needed are stored in the avg_gradients buffer
# thus, can clear this
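The gradient hand-off above is deliberately asynchronous: the sender posts dist.isend for the flattened proxy gradients, the receiver posts dist.irecv into an empty buffer, and both keep the returned handle so the transfer overlaps with the rest of the pre-step work; the handle is only waited on right before optim.step() (see the next hunk). A generic sketch of that pattern, independent of this optimizer:

    import torch
    import torch.distributed as dist

    def start_grad_transfer(role: int, peer: int, buf: torch.Tensor):
        # role 0 pushes its proxied gradients; role 1 posts the matching receive into buf
        if role == 0:
            return dist.isend(buf, dst=peer)
        return dist.irecv(buf, src=peer)

    # handle = start_grad_transfer(role, peer, flat_grads_or_empty_buffer)
    # ... overlap with norm computation, gradient clipping, etc. ...
    # handle.wait()  # must complete before the optimizer reads the received gradients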
@@ -633,9 +769,27 @@ class HybridZeroOptimizer(BaseOptimizer):
# For those ranks that are not assigned parameters, we just wait for other ranks
# to send them their own updated parameters.
if self.has_params:
if self._enable_memory_balance:
self._memory_balance_comm_handle.wait()

self.optim.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
if self._enable_memory_balance and self._memory_balance_role == 1:
self._fp32_flat_proxy_param_of_current_rank.grad = None

# receive proxy params
if self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(
rank=self._zero_local_rank, group_id=0
)
dist.recv(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
# torch.cuda.synchronize()
elif self._enable_memory_balance and self._memory_balance_role == 1:
flat_proxy_param = self._fp32_flat_proxy_param_of_current_rank.to(dtype=self._dtype)
dist.send(flat_proxy_param, self._memory_balance_peer, gpc.get_group(ParallelMode.PIPELINE))
# torch.cuda.synchronize()

# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
if self.param_group_has_params[group_id]:
@@ -677,6 +831,17 @@ class HybridZeroOptimizer(BaseOptimizer):
else:
handles.append(handle)

if group_id == 0 and self._enable_memory_balance and self._memory_balance_role == 0:
flat_proxy_param = self._param_store.get_flat_proxy_param_by_rank_group(rank, group_id)
handle = dist.broadcast(
flat_proxy_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
)

if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)

for handle in handles:
handle.wait()
@@ -146,7 +146,9 @@ class ParameterStore(BaseStore):
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
self._rank_group_id_to_flat_fp16_param = dict()
self._rank_groupid_to_flat_fp16_param = dict()
self._rank_groupid_to_proxy_param_indexs = dict()
self._rank_groupid_to_flat_proxy_param = dict()

# param reduction data structures
self._is_param_reduced = dict()
@@ -192,26 +194,55 @@ class ParameterStore(BaseStore):
tensor_rank = self._fp16_param_to_rank[tensor]
return tensor_rank == self._local_rank

def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list, indexs_to_proxy) -> None:
if rank not in self._rank_groupid_to_fp16_param_list:
self._rank_groupid_to_fp16_param_list[rank] = dict()
self._rank_groupid_to_proxy_param_indexs[rank] = dict()

if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
self._rank_groupid_to_fp16_param_list[rank][group_id] = []
self._rank_groupid_to_proxy_param_indexs[rank][group_id] = []

if indexs_to_proxy is not None:
self._rank_groupid_to_proxy_param_indexs[rank][group_id].extend(indexs_to_proxy)

self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)

def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
return self._rank_groupid_to_fp16_param_list[rank][group_id]
def get_fp16_params_by_rank_group(self, rank, group_id, option: str = "all") -> List[Tensor]:
res = []

if option == "without_proxy":
for idx, param in enumerate(self._rank_groupid_to_fp16_param_list[rank][group_id]):
if idx in self._rank_groupid_to_proxy_param_indexs[rank][group_id]:
continue
res.append(param)
elif option == "proxy_only":
for idx, param in enumerate(self._rank_groupid_to_fp16_param_list[rank][group_id]):
if idx not in self._rank_groupid_to_proxy_param_indexs[rank][group_id]:
continue
res.append(param)
else:
res = self._rank_groupid_to_fp16_param_list[rank][group_id]

return res
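A usage sketch of the new option argument (rank and group values are illustrative, and param_store is assumed to be a populated ParameterStore):

    all_params = param_store.get_fp16_params_by_rank_group(0, 0)                            # every fp16 param owned by the rank
    local_params = param_store.get_fp16_params_by_rank_group(0, 0, option="without_proxy")  # flattened and kept locally
    proxy_params = param_store.get_fp16_params_by_rank_group(0, 0, option="proxy_only")     # shipped to the pipeline peer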
def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_group_id_to_flat_fp16_param:
self._rank_group_id_to_flat_fp16_param[rank] = dict()
if rank not in self._rank_groupid_to_flat_fp16_param:
self._rank_groupid_to_flat_fp16_param[rank] = dict()

self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
self._rank_groupid_to_flat_fp16_param[rank][group_id] = tensor

def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_group_id_to_flat_fp16_param[rank][group_id]
return self._rank_groupid_to_flat_fp16_param[rank][group_id]

def add_flat_proxy_param_by_rank_group(self, rank, group_id, tensor) -> None:
if rank not in self._rank_groupid_to_flat_proxy_param:
self._rank_groupid_to_flat_proxy_param[rank] = dict()

self._rank_groupid_to_flat_proxy_param[rank][group_id] = tensor

def get_flat_proxy_param_by_rank_group(self, rank, group_id) -> Tensor:
return self._rank_groupid_to_flat_proxy_param[rank][group_id]

def is_param_reduced(self, tensor):
return self._is_param_reduced[tensor]
@@ -5,7 +5,7 @@ import math
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
@@ -317,6 +317,27 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
return total_norm


def find_subset_with_target_sum(nums: List[int], target: int, approximate_threshold: float = 0.0) -> List[int]:
indexs = []

def _inner_helper(start: int, tmpTarget: int, part_idxs: List[int]):
if len(indexs) > 0:
return

if len(part_idxs) > 0 and (
tmpTarget >= -target * approximate_threshold and tmpTarget <= target * approximate_threshold
):
indexs.append(part_idxs)
elif tmpTarget > 0:
for i in range(start, len(nums)):
num = nums[i]
# advance past i so that each element is used at most once
_inner_helper(i + 1, tmpTarget - num, part_idxs + [i])

_inner_helper(start=0, tmpTarget=target, part_idxs=[])

return indexs[0] if len(indexs) > 0 else None
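A small usage sketch (values are illustrative): the helper returns the first index subset, found depth-first, whose values sum to within approximate_threshold * target of the target, or None if no such subset exists.

    sizes = [400, 300, 350, 128]  # e.g. tensor sizes in MB
    picked = find_subset_with_target_sum(sizes, target=1024, approximate_threshold=0.05)
    # 400 + 300 + 350 = 1050 is within 1024 +/- 51.2, so picked == [0, 1, 2]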
class BaseGradScaler(ABC):
"""A base class for the gradient scaler.
train.py
@@ -177,9 +177,10 @@ def main(args):
memory_profiler = SimpleMemoryProfiler(
model,
optimizer.optim,
log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
log_folder=f"memory_trace/rank{gpc.get_global_rank()}"
+ f"_dp{gpc.get_local_rank(ParallelMode.DATA)}"
+ f"_tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
+ f"_pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
)
else:
memory_profiler = None