mirror of https://github.com/InternLM/InternLM
feat(optimizer/hybrid_zero_optim.py): fix lint error
parent
3c6925499f
commit
eac382ad0a
|
@ -1,12 +1,12 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import fused_dense_lib as fused_dense_cuda
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw
|
||||
from flash_attn.utils.distributed import all_reduce_raw
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
|
|||
grad_input = grad_input.contiguous()
|
||||
process_group = ctx.process_group
|
||||
all_gather_handler = ctx.all_gather_handler
|
||||
module = ctx.module
|
||||
block_index = ctx.block_index
|
||||
module_name = ctx.module_name
|
||||
|
||||
|
|
|
@ -11,10 +11,7 @@ from torch.optim import Optimizer
|
|||
|
||||
from internlm.core.context import Config, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.utils import (
|
||||
release_reduce_scatter_memory_pool,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from internlm.model.utils import release_reduce_scatter_memory_pool
|
||||
from internlm.monitor import send_alert_message
|
||||
from internlm.solver.optimizer.store import (
|
||||
BucketStore,
|
||||
|
|
|
@ -45,7 +45,7 @@ class BucketStore(BaseStore):
|
|||
|
||||
def num_elements_in_bucket(self, reduce_rank: int = None):
|
||||
return self._num_elements_in_bucket[reduce_rank]
|
||||
|
||||
|
||||
def num_params_in_bucket(self, reduce_rank: int = None):
|
||||
return len(self._params[reduce_rank])
|
||||
|
||||
|
|
Loading…
Reference in New Issue