[misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
pull/6022/head
Edenzzzz 2024-08-20 10:32:41 +08:00 committed by GitHub
parent f1c3266a94
commit dcc44aab8d
8 changed files with 101 additions and 70 deletions
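
For reference, here is a minimal sketch of the before/after pattern this commit applies across the plugins. It only uses what appears in the diffs below (`get_dist_logger` from `colossalai.logging`, an instance-level `self.logger`, and the `ranks=[0]` argument); the `SomePlugin` class and its `device` argument are hypothetical and serve only to illustrate the migration.

from colossalai.logging import get_dist_logger


class SomePlugin:  # hypothetical plugin, for illustration only
    def __init__(self, device=None):
        # One distributed logger per plugin instance.
        self.logger = get_dist_logger()
        if device is not None:
            # Before this commit (emitted on every rank):
            # warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
            # After this commit (printed on rank 0 only, per the "print on rank 0" note above):
            self.logger.warning(
                "The plugin will control the accelerator, so the device argument will be ignored.",
                ranks=[0],
            )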

View File

@ -1,4 +1,3 @@
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
@ -8,6 +7,8 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.logging import get_dist_logger
SUPPORT_PEFT = False
try:
import peft
@ -81,12 +82,15 @@ class Booster:
plugin, Plugin
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
self.logger = get_dist_logger()
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
if device is not None:
warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
self.logger.warning(
"The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
)
else:
device = device or "cuda"
self.accelerator = Accelerator(device)
@ -94,7 +98,10 @@ class Booster:
# set precision
if self.plugin and self.plugin.control_precision():
if mixed_precision is not None:
warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.logger.warning(
"The plugin will control the precision," "so the mixed_precision argument will be ignored.",
ranks=[0],
)
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@ -267,8 +274,9 @@ class Booster:
), "Please provide pretrained directory path if not passing in lora configuration."
if quantize is True:
if bnb_quantization_config is not None:
warnings.warn(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
self.logger.warning(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
ranks=[0],
)
else:
bnb_quantization_config = BnbQuantizationConfig(

View File

@ -1,5 +1,4 @@
import gc
import logging
import os
import random
from pathlib import Path
@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import (
)
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
@ -118,7 +119,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
return
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@ -143,10 +144,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
logging.info(
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
f"index located at {save_index_file}.",
ranks=[0],
)
def load_sharded_model(
@ -168,7 +170,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
@ -201,10 +203,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
self.logger.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
f"index located at {save_index_file}.",
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@ -214,7 +217,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])
assert isinstance(optimizer, GeminiOptimizer)
@ -369,9 +372,12 @@ class GeminiPlugin(DPPluginBase):
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"
self.logger = get_dist_logger()
if enable_async_reduce and not pin_memory:
logging.warning(
f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
self.logger.warning(
f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
ranks=[0],
)
pin_memory = True
self.gemini_config = dict(

View File

@ -1,6 +1,5 @@
import ctypes
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -1023,6 +1023,7 @@ class HybridParallelPlugin(PipelinePluginBase):
inner_ring_size: int = None,
) -> None:
super().__init__()
self.logger = get_dist_logger()
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@ -1040,8 +1041,9 @@ class HybridParallelPlugin(PipelinePluginBase):
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
self.logger.warning(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
ranks=[0],
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@ -1126,7 +1128,12 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
assert parallel_output, "Ring Attention doesn't support gathering output yet."
if not parallel_output:
self.logger.warning(
"parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
ranks=[0],
)
parallel_output = True
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@ -1231,7 +1238,10 @@ class HybridParallelPlugin(PipelinePluginBase):
optimizer = cast_to_distributed(optimizer)
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_config["partition_grad"] = False
zero_stage = 0
@ -1287,9 +1297,10 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
is_zero = self.dp_size > 1
if self.dp_size == 1:
warnings.warn(
self.logger.warning(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
"If you do not intend to use cpu_offload, please consider set zero_stage=0.",
ranks=[0],
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@ -1332,7 +1343,7 @@ class HybridParallelPlugin(PipelinePluginBase):
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
if return_outputs:
warnings.warn("return_outputs may lead to significant extra memory consumption.")
self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])
# Create a context for gradient synchronization based on the optimizer type.
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@ -1346,10 +1357,8 @@ class HybridParallelPlugin(PipelinePluginBase):
)
# run with gradients accumulation
if (
model.require_grad_sync == False
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
or not torch.is_grad_enabled()
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs
@ -1449,7 +1458,7 @@ class HybridParallelPlugin(PipelinePluginBase):
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
assert self.pp_size == 1 and self.tp_size == 1
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)

View File

@ -1,7 +1,5 @@
import enum
import logging
import os
import warnings
from contextlib import nullcontext
from functools import partial
from pathlib import Path
@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import (
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.colo_parameter import ColoParameter
@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum):
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@ -76,7 +73,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None and cast_inputs:
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
if overlap_allgather:
@ -140,7 +137,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
@ -177,10 +174,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(
self.logger.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
f"index located at {save_index_file}.",
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@ -267,7 +265,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return
from peft import PeftModel
@ -336,7 +334,6 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
cast_inputs: bool = True,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@ -363,8 +360,7 @@ class LowLevelZeroPlugin(DPPluginBase):
)
self.lora_enabled = False
self.verbose = verbose
self.cast_inputs = cast_inputs
self.logger = get_dist_logger()
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@ -400,7 +396,7 @@ class LowLevelZeroPlugin(DPPluginBase):
assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
@ -449,8 +445,9 @@ class LowLevelZeroPlugin(DPPluginBase):
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
self.logger.warning(
f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
ranks=[0],
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@ -478,10 +475,7 @@ class LowLevelZeroPlugin(DPPluginBase):
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(
model,
self.precision,
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
)
# TODO: Support Galore + ZeRO
@ -493,7 +487,10 @@ class LowLevelZeroPlugin(DPPluginBase):
optimizer = cast_to_distributed(optimizer)
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0

View File

@ -1,4 +1,3 @@
import warnings
from collections import defaultdict
from types import MethodType
from typing import Callable, List, Optional, OrderedDict, Tuple
@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
@ -215,12 +215,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_p2p: bool = True,
overlap_allgather: bool = False,
) -> None:
self.logger = get_dist_logger()
if overlap_communication or zero_stage == 2:
overlap_communication = False
zero_stage = 1
warnings.warn(
self.logger.warning(
f"overlap_communication and zero_stage are set to False and 1 because "
f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.",
ranks=[0],
)
assert (
@ -238,8 +240,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
self.logger.warning(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode},"
"will ignore the given sequence parallelism size.",
ranks=[0],
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@ -400,8 +404,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
and self.sequence_parallelism_mode == "all_to_all"
)
if use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
ranks=[0],
)
self.ddp_config["find_unused_parameters"] = True
@ -457,9 +462,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
else:
if self.dp_size <= 1:
warnings.warn(
self.logger.warning(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
"If you do not intend to use cpu_offload, please consider set zero_stage=0.",
ranks=[0],
)
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = MoeHybridParallelZeroOptimizer(

View File

@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""

View File

@ -1,6 +1,4 @@
import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
@ -30,6 +28,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from .dp_plugin_base import DPPluginBase
@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
@ -88,7 +88,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@ -117,7 +117,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
utils.save_config_file(model.unwrap(), checkpoint_path)
logging.info(
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
@ -162,7 +162,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
@ -200,7 +200,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
self.logger.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
@ -311,6 +311,7 @@ class TorchFSDPPlugin(DPPluginBase):
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
self.logger = get_dist_logger()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@ -349,7 +350,7 @@ class TorchFSDPPlugin(DPPluginBase):
if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
self.logger.warning(
"TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

View File

@ -1,7 +1,6 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import math
import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
self.verbose = verbose
self.param_groups_backup = list()
self.logger = get_dist_logger()
# Mapping from integer id to real/fake param tensor, used for checkpointing.
self.id_to_real_params: Dict[int, Parameter] = dict()
self.id_to_fake_params: Dict[int, Parameter] = dict()
@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper):
for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(
self.logger.warning(
f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!"
"You should handle its optimizer update by yourself!",
ranks=[0],
)
else:
ddp_param_list.append(param)
@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper):
*args,
**kwargs,
) -> torch.Tensor:
warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
self.logger.warning(
f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
)
class GeminiAdamOptimizer(GeminiOptimizer):