mirror of https://github.com/InternLM/InternLM
feat(optimizer/hybrid_zero_optim.py): fix lint error
parent 3c6925499f
commit eac382ad0a
@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import Any, Optional, Union
+from typing import Optional

 import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import all_reduce_raw  # , reduce_scatter_raw
+from flash_attn.utils.distributed import all_reduce_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
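The names dropped here (Any, Union, and the commented-out reduce_scatter_raw import) are exactly what unused-import checks flag. A minimal sketch of the rule with toy names, not code from the repository:

    from typing import Any, Optional, Union  # Any and Union are never used below


    def maybe_double(x: Optional[int]) -> Optional[int]:
        # Only Optional is referenced, so flake8 reports F401 ("imported but
        # unused") and pylint reports W0611 (unused-import) for Any and Union;
        # deleting the unused names, as this hunk does, clears the lint.
        return None if x is None else 2 * x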
@@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
         block_index = ctx.block_index
         module_name = ctx.module_name

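The removed line bound ctx.module to a local that was never used again, which linters flag as a dead assignment (flake8 F841 / pylint W0612). A minimal sketch of the torch.autograd.Function pattern involved, using a toy function rather than the repository's FSTPFusedDenseFunc:

    import torch
    from torch.cuda.amp import custom_bwd, custom_fwd


    class ScaleFunc(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, x, scale):
            # Arbitrary attributes can be stashed on ctx during forward ...
            ctx.scale = scale
            return x * scale

        @staticmethod
        @custom_bwd
        def backward(ctx, grad_output):
            # ... and read back in backward.  Reading an attribute into a
            # local that feeds nothing (as module = ctx.module did) is dead
            # code, so the commit simply deletes that line.
            return grad_output * ctx.scale, None

Calling ScaleFunc.apply(x, 3.0) and backpropagating yields 3.0 * grad for x and no gradient for the plain-float scale.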
@@ -11,10 +11,7 @@ from torch.optim import Optimizer

 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import (
-    release_reduce_scatter_memory_pool,
-    split_forward_gather_backward,
-)
+from internlm.model.utils import release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -45,7 +45,7 @@ class BucketStore(BaseStore):

     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]

     def num_params_in_bucket(self, reduce_rank: int = None):
         return len(self._params[reduce_rank])

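For orientation, the two accessors shown above read per-rank bucket bookkeeping. Below is a minimal sketch of a store with the same accessors; everything except num_elements_in_bucket and num_params_in_bucket (the add_param helper and the exact field types) is an assumption for illustration, not the repository's BucketStore:

    from collections import defaultdict

    import torch


    class TinyBucketStore:
        def __init__(self):
            # Buckets are keyed by reduce_rank; None acts as the default bucket.
            self._num_elements_in_bucket = defaultdict(int)
            self._params = defaultdict(list)

        def add_param(self, param: torch.Tensor, reduce_rank: int = None):
            # Track both the element count and the parameter list per rank.
            self._num_elements_in_bucket[reduce_rank] += param.numel()
            self._params[reduce_rank].append(param)

        def num_elements_in_bucket(self, reduce_rank: int = None):
            return self._num_elements_in_bucket[reduce_rank]

        def num_params_in_bucket(self, reduce_rank: int = None):
            return len(self._params[reduce_rank])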