mirror of https://github.com/hpcaitech/ColossalAI
[chore] minor fix after rebase
parent
803878b2fd
commit
46037c2ccd
|
@ -39,7 +39,6 @@ from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
|||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .pp_plugin_base import PipelinePluginBase
|
||||
|
||||
|
@ -653,6 +652,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
model: HybridParallelModule,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
min_scale: int = 1,
|
||||
growth_factor: float = 2.0,
|
||||
|
@ -685,6 +685,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
optimizer=optimizer,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
pg_to_param_list=pg_to_param_list,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
|
@ -1124,7 +1125,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
|
||||
self.logger.info(
|
||||
f"{type(self).__name__}: dp_group {dist.get_process_group_ranks(self.dp_group)} pp_group {dist.get_process_group_ranks(self.pp_group)} tp_group {dist.get_process_group_ranks(self.tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
|
||||
ranks=[0, 1, 2, 3, 4, 5, 6, 7],
|
||||
ranks=[0],
|
||||
)
|
||||
self.shard_config = ShardConfig(
|
||||
tensor_parallel_process_group=self.tp_group,
|
||||
|
|
|
@ -55,6 +55,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
partition_grad: bool = False, # stage 2 flag
|
||||
cpu_offload: bool = False, # cpu offload
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
|
||||
if not force_overlap_comm and (overlap_communication or partition_grad):
|
||||
|
@ -95,6 +96,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
pp_process_group=pp_process_group,
|
||||
forced_dtype=forced_dtype,
|
||||
pg_to_param_list=pg_param_list,
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -47,7 +47,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||
dtype, precision = torch.float16, "fp16"
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
|
||||
print(config)
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=pp_size,
|
||||
num_microbatches=pp_size,
|
||||
|
|
Loading…
Reference in New Issue