[misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Edenzzzz 2024-08-20 10:32:41 +08:00 committed by GitHub
parent f1c3266a94
commit dcc44aab8d
8 changed files with 101 additions and 70 deletions
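
Every file in this commit follows the same pattern: ad-hoc `warnings.warn(...)` and stdlib `logging.*` calls in the booster and plugins are routed through ColossalAI's distributed logger, with output restricted to rank 0. A minimal sketch of that pattern, for orientation only (it is not part of the diff and assumes the distributed environment has already been initialized, e.g. via `colossalai.launch_from_torch`):

# Sketch of the logging pattern adopted below; assumes colossalai/torch.distributed
# has been initialized so that ranks are known (e.g. via colossalai.launch_from_torch).
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

# Before: emitted independently by every process.
#   warnings.warn("return_outputs may lead to significant extra memory consumption.")
# After: routed through the dist logger and printed on rank 0 only.
logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])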

View File

@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union
@@ -8,6 +7,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+from colossalai.logging import get_dist_logger
+
 SUPPORT_PEFT = False
 try:
     import peft
@@ -81,12 +82,15 @@ class Booster:
                 plugin, Plugin
             ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
         self.plugin = plugin
+        self.logger = get_dist_logger()
 
         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
             if device is not None:
-                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
+                )
         else:
             device = device or "cuda"
             self.accelerator = Accelerator(device)
@@ -94,7 +98,10 @@ class Booster:
         # set precision
         if self.plugin and self.plugin.control_precision():
             if mixed_precision is not None:
-                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the precision," "so the mixed_precision argument will be ignored.",
+                    ranks=[0],
+                )
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -267,8 +274,9 @@ class Booster:
             ), "Please provide pretrained directory path if not passing in lora configuration."
         if quantize is True:
             if bnb_quantization_config is not None:
-                warnings.warn(
-                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                self.logger.warning(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+                    ranks=[0],
                 )
             else:
                 bnb_quantization_config = BnbQuantizationConfig(

View File

@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import random
 from pathlib import Path
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
@@ -118,7 +119,7 @@
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_model(
@@ -168,7 +170,7 @@
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            logging.info(
+            self.logger.info(
                 f"The optimizer is going to be split to checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )
 
     def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@
         """
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
-            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+            self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])
 
         assert isinstance(optimizer, GeminiOptimizer)
@@ -369,9 +372,12 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+        self.logger = get_dist_logger()
+
         if enable_async_reduce and not pin_memory:
-            logging.warning(
-                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
+            self.logger.warning(
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
+                ranks=[0],
             )
             pin_memory = True
         self.gemini_config = dict(

View File

@@ -1,6 +1,5 @@
 import ctypes
 import random
-import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1023,6 +1023,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()
+        self.logger = get_dist_logger()
 
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1040,8 +1041,9 @@
                     tp_size > 1
                 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                 if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
+                        ranks=[0],
                     )
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -1126,7 +1128,12 @@
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
-            assert parallel_output, "Ring Attention doesn't support gathering output yet."
+            if not parallel_output:
+                self.logger.warning(
+                    "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
+                    ranks=[0],
+                )
+                parallel_output = True
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1231,7 +1238,10 @@
             optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_config["partition_grad"] = False
             zero_stage = 0
@@ -1287,9 +1297,10 @@
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
-                    warnings.warn(
+                    self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                        ranks=[0],
                     )
 
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1332,7 +1343,7 @@
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
 
         if return_outputs:
-            warnings.warn("return_outputs may lead to significant extra memory consumption.")
+            self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])
 
         # Create a context for gradient synchronization based on the optimizer type.
         # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@@ -1346,10 +1357,8 @@
         )
 
         # run with gradients accumulation
-        if (
-            model.require_grad_sync == False
-            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
-            or not torch.is_grad_enabled()
+        if model.require_grad_sync == False or (
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
         ):
             return outputs
@@ -1449,7 +1458,7 @@
         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
        assert self.pp_size == 1 and self.tp_size == 1
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
 
         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
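
Note that the `ring_attn` hunk above is the one behavioral change beyond the logger migration: passing `parallel_output=False` with Zigzag Ring Attention no longer trips an assertion; the plugin now warns on rank 0 and forces `parallel_output = True`. A hypothetical, abbreviated illustration from the caller's side (`enable_sequence_parallelism` is an assumed constructor argument not visible in this diff, and a running torch.distributed environment is required):

# Hypothetical configuration sketch; not part of the commit.
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,  # assumed argument, not shown in this diff
    sequence_parallelism_mode="ring_attn",
    parallel_output=False,  # previously raised AssertionError; now warned about on rank 0 and overridden to True
)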

View File

@@ -1,7 +1,5 @@
 import enum
-import logging
 import os
-import warnings
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
@@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum):
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
-    ) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -76,7 +73,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
-        if self.dtype is not None and cast_inputs:
+        if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         if overlap_allgather:
@@ -140,7 +137,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -177,10 +174,11 @@
         index_file.append_meta_data("total_size", total_size)
         if self.coordinator.is_master():
             index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
+            f"index located at {save_index_file}.",
+            ranks=[0],
         )
 
     def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@@ -267,7 +265,7 @@
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
         from peft import PeftModel
@@ -336,7 +334,6 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
-        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -363,8 +360,7 @@
         )
         self.lora_enabled = False
         self.verbose = verbose
-        self.cast_inputs = cast_inputs
+        self.logger = get_dist_logger()
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -400,7 +396,7 @@
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
 
         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
@@ -449,8 +445,9 @@
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn(
-                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    self.logger.warning(
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
+                        ranks=[0],
                     )
                 elif (
                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -478,10 +475,7 @@
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model,
-                self.precision,
-                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                cast_inputs=self.cast_inputs,
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
             )
 
         # TODO: Support Galore + ZeRO
@@ -493,7 +487,10 @@
             optimizer = cast_to_distributed(optimizer)
 
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_optim_kwargs["partition_grad"] = False
             zero_stage = 0

View File

@@ -1,4 +1,3 @@
-import warnings
 from collections import defaultdict
 from types import MethodType
 from typing import Callable, List, Optional, OrderedDict, Tuple
@@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
@@ -215,12 +215,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
     ) -> None:
+        self.logger = get_dist_logger()
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
             zero_stage = 1
-            warnings.warn(
+            self.logger.warning(
                 f"overlap_communication and zero_stage are set to False and 1 because "
-                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
+                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.",
+                ranks=[0],
             )
 
         assert (
@@ -238,8 +240,10 @@
                     tp_size > 1
                 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                 if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode},"
+                        "will ignore the given sequence parallelism size.",
+                        ranks=[0],
                     )
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -400,8 +404,9 @@
             and self.sequence_parallelism_mode == "all_to_all"
         )
         if use_ddp:
-            warnings.warn(
-                f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
+            self.logger.warning(
+                f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
+                ranks=[0],
            )
             self.ddp_config["find_unused_parameters"] = True
@@ -457,9 +462,10 @@
             )
         else:
             if self.dp_size <= 1:
-                warnings.warn(
+                self.logger.warning(
                     "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                    "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                    "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                    ranks=[0],
                 )
             assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
             optimizer = MoeHybridParallelZeroOptimizer(

View File

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
@@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """

View File

@@ -1,6 +1,4 @@
-import logging
 import os
-import warnings
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
@@ -30,6 +28,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 
 from .dp_plugin_base import DPPluginBase
@@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()
 
     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
@@ -88,7 +88,7 @@
         """
         assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
 
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -117,7 +117,7 @@
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             utils.save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
                 f"index located at {save_index_file}."
@@ -162,7 +162,7 @@
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -200,7 +200,7 @@
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            logging.info(
+            self.logger.info(
                 f"The optimizer is going to be split to checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
                 f"index located at {save_index_file}."
@@ -311,6 +311,7 @@ class TorchFSDPPlugin(DPPluginBase):
                 param_init_fn=param_init_fn,
                 sync_module_states=sync_module_states,
             )
+            self.logger = get_dist_logger()
         else:
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -349,7 +350,7 @@
         if optimizer is not None:
             if len(optimizer.param_groups) > 1:
-                warnings.warn(
+                self.logger.warning(
                     "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
                 )
             optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

View File

@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-import warnings
 from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
 
 import torch
@@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
+        self.logger = get_dist_logger()
 
         # Mapping from integer id to real/fake param tensor, used for checkpointing.
         self.id_to_real_params: Dict[int, Parameter] = dict()
         self.id_to_fake_params: Dict[int, Parameter] = dict()
@@ -148,9 +147,10 @@
         for name, param in module.named_parameters():
             if is_ddp_ignored(param):
                 if param.requires_grad:
-                    warnings.warn(
+                    self.logger.warning(
                         f"Parameter `{name}` is ignored by DDP but requires gradient! "
-                        "You should handle its optimizer update by yourself!"
+                        "You should handle its optimizer update by yourself!",
+                        ranks=[0],
                     )
             else:
                 ddp_param_list.append(param)
@@ -842,7 +842,9 @@
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
+        self.logger.warning(
+            f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
+        )
 
 
 class GeminiAdamOptimizer(GeminiOptimizer):