diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index b9c7c03..19531e4 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Optional, Union
+from typing import Optional
 
 import fused_dense_lib as fused_dense_cuda
 import torch
 import torch.nn.functional as F
-from flash_attn.utils.distributed import all_reduce_raw  # , reduce_scatter_raw
+from flash_attn.utils.distributed import all_reduce_raw
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
@@ -397,7 +397,6 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
         block_index = ctx.block_index
         module_name = ctx.module_name
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d5fec31..cb8aa65 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -11,10 +11,7 @@ from torch.optim import Optimizer
 
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.utils import (
-    release_reduce_scatter_memory_pool,
-    split_forward_gather_backward,
-)
+from internlm.model.utils import release_reduce_scatter_memory_pool
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py
index 228045e..f486cce 100644
--- a/internlm/solver/optimizer/store.py
+++ b/internlm/solver/optimizer/store.py
@@ -45,7 +45,7 @@ class BucketStore(BaseStore):
 
     def num_elements_in_bucket(self, reduce_rank: int = None):
         return self._num_elements_in_bucket[reduce_rank]
-    
+
     def num_params_in_bucket(self, reduce_rank: int = None):
         return len(self._params[reduce_rank])
 