feat(optimizer/hybrid_zero_optim.py): fix lint error

pull/456/head
huangting4201 2023-10-20 16:22:29 +08:00
parent 3c6925499f
commit eac382ad0a
3 changed files with 4 additions and 8 deletions

View File

@ -1,12 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Optional, Union
from typing import Optional
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import all_reduce_raw # , reduce_scatter_raw
from flash_attn.utils.distributed import all_reduce_raw
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
grad_input = grad_input.contiguous()
process_group = ctx.process_group
all_gather_handler = ctx.all_gather_handler
module = ctx.module
block_index = ctx.block_index
module_name = ctx.module_name

View File

@ -11,10 +11,7 @@ from torch.optim import Optimizer
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import (
release_reduce_scatter_memory_pool,
split_forward_gather_backward,
)
from internlm.model.utils import release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,

View File

@ -45,7 +45,7 @@ class BucketStore(BaseStore):
def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def num_params_in_bucket(self, reduce_rank: int = None):
return len(self._params[reduce_rank])