
[chore] minor fix after rebase

colossalchat
hxwang authored 4 months ago, committed by Hongxin Liu
commit 46037c2ccd
3 changed files with 5 additions and 3 deletions:
  1. colossalai/booster/plugin/hybrid_parallel_plugin.py (5 changed lines)
  2. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 changed lines)
  3. tests/test_shardformer/test_model/test_shard_deepseek.py (1 changed line)

colossalai/booster/plugin/hybrid_parallel_plugin.py (5 changed lines)

@@ -39,7 +39,6 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
-from colossalai.logging import get_dist_logger
 from .pp_plugin_base import PipelinePluginBase
@@ -653,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
+        pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
@@ -685,6 +685,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             optimizer=optimizer,
             initial_scale=initial_scale,
             min_scale=min_scale,
+            pg_to_param_list=pg_to_param_list,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
@@ -1124,7 +1125,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.logger.info(
             f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
-            ranks=[0, 1, 2, 3, 4, 5, 6, 7],
+            ranks=[0],
         )
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
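
The pg_to_param_list keyword added to HybridParallelZeroOptimizer is passed straight through to the underlying LowLevelZeroOptimizer, letting callers decide which process group shards the optimizer states of which parameters. Below is a minimal sketch of how such a mapping could be assembled; build_pg_to_param_list, dense_pg, expert_pg, and the "experts" name check are illustrative assumptions, not part of the ColossalAI API.

# Hedged sketch: assembling a Dict[ProcessGroup, List[torch.nn.Parameter]] of the
# shape pg_to_param_list expects. The helper name, the group names, and the
# "experts" naming convention are assumptions for illustration only.
from typing import Dict, List

import torch
from torch.distributed import ProcessGroup


def build_pg_to_param_list(
    model: torch.nn.Module,
    dense_pg: ProcessGroup,
    expert_pg: ProcessGroup,
) -> Dict[ProcessGroup, List[torch.nn.Parameter]]:
    """Group parameters by the process group that should shard their optimizer states."""
    mapping: Dict[ProcessGroup, List[torch.nn.Parameter]] = {dense_pg: [], expert_pg: []}
    for name, param in model.named_parameters():
        # Assumption: expert parameters are recognizable by a name substring.
        target = expert_pg if "experts" in name else dense_pg
        mapping[target].append(param)
    return mapping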

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 changed lines)

@@ -55,6 +55,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
+        overlap_allgather: bool = False,
     ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
@@ -95,6 +96,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
             pg_to_param_list=pg_param_list,
+            overlap_allgather=overlap_allgather,
         )
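
The new overlap_allgather flag is simply forwarded to the parent HybridParallelZeroOptimizer. Conceptually, a flag with this name overlaps the ZeRO parameter all-gather with other work instead of blocking on it; the sketch below illustrates that general pattern with plain torch.distributed calls and a hypothetical allgather_overlapped helper, not ColossalAI's ZeroOpHook/wait_all_gather_handle implementation.

# Generic illustration of the overlap pattern an overlap_allgather-style flag toggles:
# launch the all-gather asynchronously and wait on the handle only right before the
# gathered tensor is needed. Assumes torch.distributed is already initialized.
import torch
import torch.distributed as dist


def allgather_overlapped(shard: torch.Tensor, group: dist.ProcessGroup):
    world_size = dist.get_world_size(group)
    full = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    # async_op=True returns a work handle instead of blocking the caller.
    handle = dist.all_gather_into_tensor(full, shard.contiguous(), group=group, async_op=True)
    return full, handle


# Usage sketch: kick off the gather, run independent computation, then wait just in time.
# full, handle = allgather_overlapped(param_shard, dp_group)
# ... other work that does not need `full` ...
# handle.wait()  # ensure the gathered parameter is ready before it is read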

tests/test_shardformer/test_model/test_shard_deepseek.py (1 changed line)

@@ -47,7 +47,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dtype, precision = torch.float16, "fp16"
     torch.cuda.set_device(dist.get_rank())
-    print(config)
     plugin = MoeHybridParallelPlugin(
         pp_size=pp_size,
         num_microbatches=pp_size,