mirror of https://github.com/hpcaitech/ColossalAI
[misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

branch: pull/6022/head
parent: f1c3266a94
commit: dcc44aab8d
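The change itself is mechanical: calls to warnings.warn(...) and to the standard logging module inside the booster and its plugins are replaced with the distributed logger from colossalai.logging, passing ranks=[0] so each message is emitted once on rank 0 rather than once per process. Below is a minimal sketch of that pattern, assuming the distributed environment is already initialized (e.g. via colossalai.launch_from_torch); the actual call sites and messages are in the diff that follows.

from colossalai.logging import get_dist_logger

logger = get_dist_logger()

# warnings.warn / logging.warning print the message on every rank;
# the distributed logger emits it only on the ranks listed in `ranks`.
logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])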
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union

@@ -8,6 +7,8 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+from colossalai.logging import get_dist_logger
+
 SUPPORT_PEFT = False
 try:
     import peft
@@ -81,12 +82,15 @@ class Booster:
             plugin, Plugin
         ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
         self.plugin = plugin
+        self.logger = get_dist_logger()

         # set accelerator
         if self.plugin and self.plugin.control_device():
             self.accelerator = None
             if device is not None:
-                warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
+                )
         else:
             device = device or "cuda"
             self.accelerator = Accelerator(device)
@@ -94,7 +98,10 @@ class Booster:
         # set precision
         if self.plugin and self.plugin.control_precision():
             if mixed_precision is not None:
-                warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
+                self.logger.warning(
+                    "The plugin will control the precision," "so the mixed_precision argument will be ignored.",
+                    ranks=[0],
+                )
             self.mixed_precision = None
         elif mixed_precision is None:
             self.mixed_precision = None
@@ -267,8 +274,9 @@ class Booster:
         ), "Please provide pretrained directory path if not passing in lora configuration."
         if quantize is True:
             if bnb_quantization_config is not None:
-                warnings.warn(
-                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
+                self.logger.warning(
+                    "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
+                    ranks=[0],
                 )
             else:
                 bnb_quantization_config = BnbQuantizationConfig(

@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 import random
 from pathlib import Path
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()

     def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
@@ -118,7 +119,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
             return

         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
             save_config_file(model.unwrap(), checkpoint_path)
-            logging.info(
+            self.logger.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )

     def load_sharded_model(
@@ -168,7 +170,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"

         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return

         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            logging.info(
+            self.logger.info(
                 f"The optimizer is going to be split to checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
+                f"index located at {save_index_file}.",
+                ranks=[0],
             )

     def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
-            logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
+            self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])

         assert isinstance(optimizer, GeminiOptimizer)

@@ -369,9 +372,12 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+
+        self.logger = get_dist_logger()
         if enable_async_reduce and not pin_memory:
-            logging.warning(
-                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
+            self.logger.warning(
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
+                ranks=[0],
             )
             pin_memory = True
         self.gemini_config = dict(

@@ -1,6 +1,5 @@
 import ctypes
 import random
-import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
@@ -27,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1023,6 +1023,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()
+        self.logger = get_dist_logger()

         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1040,8 +1041,9 @@ class HybridParallelPlugin(PipelinePluginBase):
                     tp_size > 1
                 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                 if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
+                        ranks=[0],
                     )
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -1126,7 +1128,12 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
-            assert parallel_output, "Ring Attention doesn't support gathering output yet."
+            if not parallel_output:
+                self.logger.warning(
+                    "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
+                    ranks=[0],
+                )
+                parallel_output = True

         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1231,7 +1238,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         optimizer = cast_to_distributed(optimizer)

         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_config["partition_grad"] = False
             zero_stage = 0

@@ -1287,9 +1297,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             else:
                 is_zero = self.dp_size > 1
                 if self.dp_size == 1:
-                    warnings.warn(
+                    self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                        ranks=[0],
                     )

                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1332,7 +1343,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"

         if return_outputs:
-            warnings.warn("return_outputs may lead to significant extra memory consumption.")
+            self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])

         # Create a context for gradient synchronization based on the optimizer type.
         # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@@ -1346,10 +1357,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         )

         # run with gradients accumulation
-        if (
-            model.require_grad_sync == False
-            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
-            or not torch.is_grad_enabled()
+        if model.require_grad_sync == False or (
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
         ):
             return outputs

@@ -1449,7 +1458,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
         assert self.pp_size == 1 and self.tp_size == 1
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)

@@ -1,7 +1,5 @@
 import enum
-import logging
 import os
-import warnings
 from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
@@ -33,6 +31,7 @@ from colossalai.checkpoint_io.utils import (
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
-    ) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -76,7 +73,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
-        if self.dtype is not None and cast_inputs:
+        if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         if overlap_allgather:
@@ -140,7 +137,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return

         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -177,10 +174,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         if self.coordinator.is_master():
             index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
+            f"index located at {save_index_file}.",
+            ranks=[0],
         )

     def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
@@ -267,7 +265,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
             return
         from peft import PeftModel

@@ -336,7 +334,6 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
-        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -363,8 +360,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         )
         self.lora_enabled = False
         self.verbose = verbose
-        self.cast_inputs = cast_inputs
-
+        self.logger = get_dist_logger()

         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

@@ -400,7 +396,7 @@ class LowLevelZeroPlugin(DPPluginBase):

         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
-        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+        self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

         if bnb_quantization_config is not None:
             model = quantize_model(model, bnb_quantization_config)
@@ -449,8 +445,9 @@ class LowLevelZeroPlugin(DPPluginBase):
                 origin_param = name2param[origin_key]
                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
-                    warnings.warn(
-                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    self.logger.warning(
+                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
+                        ranks=[0],
                     )
                 elif (
                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -478,10 +475,7 @@ class LowLevelZeroPlugin(DPPluginBase):

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model,
-                self.precision,
-                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
-                cast_inputs=self.cast_inputs,
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
             )

         # TODO: Support Galore + ZeRO
@@ -493,7 +487,10 @@ class LowLevelZeroPlugin(DPPluginBase):
         optimizer = cast_to_distributed(optimizer)

         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
-            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            self.logger.warning(
+                "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
+                ranks=[0],
+            )
             zero_optim_kwargs["partition_grad"] = False
             zero_stage = 0

@@ -1,4 +1,3 @@
-import warnings
 from collections import defaultdict
 from types import MethodType
 from typing import Callable, List, Optional, OrderedDict, Tuple
@@ -26,6 +25,7 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
@@ -215,12 +215,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
     ) -> None:
+        self.logger = get_dist_logger()
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
             zero_stage = 1
-            warnings.warn(
+            self.logger.warning(
                 f"overlap_communication and zero_stage are set to False and 1 because "
-                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
+                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.",
+                ranks=[0],
             )

         assert (
@@ -238,8 +240,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     tp_size > 1
                 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
                 if sp_size != 1:
-                    warnings.warn(
-                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    self.logger.warning(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode},"
+                        "will ignore the given sequence parallelism size.",
+                        ranks=[0],
                     )
                 self.sp_size = 1
                 self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -400,8 +404,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             and self.sequence_parallelism_mode == "all_to_all"
         )
         if use_ddp:
-            warnings.warn(
-                f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
+            self.logger.warning(
+                f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
+                ranks=[0],
             )
             self.ddp_config["find_unused_parameters"] = True

@@ -457,9 +462,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 )
             else:
                 if self.dp_size <= 1:
-                    warnings.warn(
+                    self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0.",
+                        ranks=[0],
                     )
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                 optimizer = MoeHybridParallelZeroOptimizer(

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device

@@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()

     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """

@@ -1,6 +1,4 @@
-import logging
 import os
-import warnings
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

@@ -30,6 +28,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.logging import get_dist_logger

 from .dp_plugin_base import DPPluginBase

@@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
         self.coordinator = DistCoordinator()
+        self.logger = get_dist_logger()

     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
@@ -88,7 +88,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
-            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return

         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -117,7 +117,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
         utils.save_config_file(model.unwrap(), checkpoint_path)
-        logging.info(
+        self.logger.info(
             f"The model is split into checkpoint shards. "
             f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
@@ -162,7 +162,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"

         if os.path.isfile(checkpoint):
-            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return

         Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -200,7 +200,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
+        self.logger.info(
             f"The optimizer is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
@@ -311,6 +311,7 @@ class TorchFSDPPlugin(DPPluginBase):
                 param_init_fn=param_init_fn,
                 sync_module_states=sync_module_states,
             )
+            self.logger = get_dist_logger()

         else:
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -349,7 +350,7 @@ class TorchFSDPPlugin(DPPluginBase):

         if optimizer is not None:
             if len(optimizer.param_groups) > 1:
-                warnings.warn(
+                self.logger.warning(
                     "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used."
                 )
             optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-import warnings
 from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union

 import torch
@@ -136,7 +135,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
-
+        self.logger = get_dist_logger()
         # Mapping from integer id to real/fake param tensor, used for checkpointing.
         self.id_to_real_params: Dict[int, Parameter] = dict()
         self.id_to_fake_params: Dict[int, Parameter] = dict()
@@ -148,9 +147,10 @@ class GeminiOptimizer(OptimizerWrapper):
         for name, param in module.named_parameters():
             if is_ddp_ignored(param):
                 if param.requires_grad:
-                    warnings.warn(
+                    self.logger.warning(
                         f"Parameter `{name}` is ignored by DDP but requires gradient! "
-                        "You should handle its optimizer update by yourself!"
+                        "You should handle its optimizer update by yourself!",
+                        ranks=[0],
                     )
                 else:
                     ddp_param_list.append(param)
@@ -842,7 +842,9 @@ class GeminiOptimizer(OptimizerWrapper):
         *args,
         **kwargs,
     ) -> torch.Tensor:
-        warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
+        self.logger.warning(
+            f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
+        )


 class GeminiAdamOptimizer(GeminiOptimizer):