mirror of https://github.com/InternLM/InternLM
fix(optimizer/util.py): change inf definition
parent 754c5aa69a
commit ad10b8e03f
@@ -5,11 +5,11 @@ from abc import ABC, abstractmethod
 from typing import Dict, Optional
 
 import amp_C
+import math
 import torch
 import torch.distributed as dist
 from apex.multi_tensor_apply import multi_tensor_applier
 from torch import Tensor
-from torch._six import inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
 from internlm.core.context import ParallelMode
@@ -18,6 +18,8 @@ from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
 from internlm.utils.logger import get_logger
 from internlm.utils.parallel import is_model_parallel_parameter
 
+inf = math.inf
+
 logger = get_logger(__file__)
 
 
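Context for the change (hedged, from general PyTorch knowledge rather than anything stated in the commit): torch._six was a private compatibility shim that recent PyTorch releases removed, so `from torch._six import inf` raises ModuleNotFoundError on those versions. math.inf is the same IEEE-754 positive-infinity float, which makes `inf = math.inf` a drop-in replacement, presumably for the gradient-norm helpers in util.py that compare against inf. Below is a minimal version-tolerant sketch; note the committed code simply binds math.inf unconditionally, and the try/except fallback here is an illustration, not the repo's code:

import math

try:
    # Older PyTorch exposed inf through the private torch._six shim.
    from torch._six import inf
except ImportError:
    # Recent PyTorch removed torch._six; math.inf is the identical
    # IEEE-754 positive-infinity float, so it is a drop-in substitute.
    inf = math.inf

# All three spellings denote the same float value.
assert inf == math.inf == float("inf")

Binding inf once at module scope keeps downstream call sites (e.g. code comparing norm_type == inf) unchanged, which is why the fix touches only the import block rather than every usage.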