diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py
new file mode 100644
index 000000000..4021b3175
--- /dev/null
+++ b/colossalai/booster/plugin/dp_plugin_base.py
@@ -0,0 +1,72 @@
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from .plugin_base import Plugin
+
+
+class DPPluginBase(Plugin):
+    """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        assert dist.is_initialized(
+        ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+
+    def prepare_train_dataloader(self,
+                                 dataset,
+                                 batch_size,
+                                 shuffle=False,
+                                 seed=1024,
+                                 drop_last=False,
+                                 pin_memory=False,
+                                 num_workers=0,
+                                 **kwargs):
+        r"""
+        Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+        Note:
+            1. Evaluation datasets should not be passed to this function.
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            batch_size (int): The number of samples in each batch.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling, defaults to 1024.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of dataset is not divisible by
+                the batch size, then the last batch will be smaller, defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
+            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+                    `DataLoader `_.
+
+        Returns:
+            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+        """
+        _kwargs = kwargs.copy()
+        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(dataset,
+                          batch_size=batch_size,
+                          sampler=sampler,
+                          worker_init_fn=seed_worker,
+                          drop_last=drop_last,
+                          pin_memory=pin_memory,
+                          num_workers=num_workers,
+                          **_kwargs)
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index dfdd7be26..fde8912a6 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -1,36 +1,25 @@
-import random
-import warnings
-from typing import Callable, List, Optional, Tuple, Union
-from pathlib import Path
-import os
 import logging
+import os
+import warnings
+from pathlib import Path
+from typing import Callable, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
 
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
-from colossalai.checkpoint_io.utils import save_state_dict
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.zero.gemini.memory_tracer import MemStats
-from colossalai.checkpoint_io.utils import (
-    get_base_filenames,
-    get_shard_filename
-    )
-
-from colossalai.checkpoint_io import CheckpointIndexFile
-
-from .plugin_base import Plugin
+from .dp_plugin_base import DPPluginBase
 
 __all__ = ['GeminiPlugin']
 
@@ -72,7 +61,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
+    def save_sharded_model(self,
+                           model: GeminiDDP,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
         """
         Save sharded model
         """
@@ -88,25 +83,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
-
+
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
-
+
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
-        logging.info(
-            f"The model is going to be split to checkpoint shards. "
-            f"You can find where each parameters has been saved in the "
-            f"index located at {save_index_file}."
-        )
+        logging.info(f"The model is going to be split into checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
 
-
-    def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+    def load_sharded_model(self,
+                           model: GeminiDDP,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False):
         """
         load shard model, load model from multiple files
         """
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
+
 class GeminiModel(ModelWrapper):
 
     def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
@@ -148,7 +145,7 @@ class GeminiOptimizer(OptimizerWrapper):
         raise NotImplementedError('Gemini does not support clip_grad_by_value')
 
 
-class GeminiPlugin(Plugin):
+class GeminiPlugin(DPPluginBase):
     """
     Plugin for Gemini.
 
@@ -217,11 +214,7 @@ class GeminiPlugin(Plugin):
         norm_type: float = 2.0,
         verbose: bool = False,
     ) -> None:
-
-        assert dist.is_initialized(
-        ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
+        super().__init__()
         self.gemini_config = dict(
             device=(device or get_current_device()),
             placement_policy=placement_policy,
@@ -260,57 +253,6 @@ class GeminiPlugin(Plugin):
     def supported_devices(self) -> List[str]:
         return ['cuda']
 
-    def prepare_train_dataloader(self,
-                                 dataset,
-                                 batch_size,
-                                 shuffle=False,
-                                 seed=1024,
-                                 drop_last=False,
-                                 pin_memory=False,
-                                 num_workers=0,
-                                 **kwargs):
-        r"""
-        Prepare a dataloader for distributed training. The dataloader will be wrapped by
-        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-
-        Note:
-            1. Evaluation datasets should not be passed to this function.
-
-        Args:
-            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
-            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
-            seed (int, optional): Random worker seed for sampling, defaults to 1024.
-            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
-            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
-                is not divisible by the batch size. If False and the size of dataset is not divisible by
-                the batch size, then the last batch will be smaller, defaults to False.
-            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
-            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
-            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
-                    `DataLoader `_.
-
-        Returns:
-            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
-        """
-        _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
-
-        # Deterministic dataloader
-        def seed_worker(worker_id):
-            worker_seed = seed
-            np.random.seed(worker_seed)
-            torch.manual_seed(worker_seed)
-            random.seed(worker_seed)
-
-        return DataLoader(dataset,
-                          batch_size=batch_size,
-                          sampler=sampler,
-                          worker_init_fn=seed_worker,
-                          drop_last=drop_last,
-                          pin_memory=pin_memory,
-                          num_workers=num_workers,
-                          **_kwargs)
-
     def configure(
         self,
         model: nn.Module,
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 969c430bd..828d8b274 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,24 +1,20 @@
-import random
 import warnings
 from typing import Callable, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.checkpoint_io import CheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
 
-from .plugin_base import Plugin
+from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
 
 __all__ = ['LowLevelZeroPlugin']
@@ -88,7 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
 
 
-class LowLevelZeroPlugin(Plugin):
+class LowLevelZeroPlugin(DPPluginBase):
     """
     Plugin for low level zero.
 
@@ -142,15 +138,10 @@ class LowLevelZeroPlugin(Plugin):
         cpu_offload: bool = False,
         verbose: bool = False,
     ) -> None:
-
-        assert dist.is_initialized(
-        ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
+        super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
 
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-
         self.stage = stage
         self.precision = precision
         self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
@@ -183,57 +174,6 @@ class LowLevelZeroPlugin(Plugin):
     def supported_devices(self) -> List[str]:
         return ['cuda']
 
-    def prepare_train_dataloader(self,
-                                 dataset,
-                                 batch_size,
-                                 shuffle=False,
-                                 seed=1024,
-                                 drop_last=False,
-                                 pin_memory=False,
-                                 num_workers=0,
-                                 **kwargs):
-        r"""
-        Prepare a dataloader for distributed training. The dataloader will be wrapped by
-        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-
-        Note:
-            1. Evaluation datasets should not be passed to this function.
-
-        Args:
-            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
-            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
-            seed (int, optional): Random worker seed for sampling, defaults to 1024.
-            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
-            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
-                is not divisible by the batch size. If False and the size of dataset is not divisible by
-                the batch size, then the last batch will be smaller, defaults to False.
-            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
-            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
-            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
-                    `DataLoader `_.
-
-        Returns:
-            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
-        """
-        _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
-
-        # Deterministic dataloader
-        def seed_worker(worker_id):
-            worker_seed = seed
-            np.random.seed(worker_seed)
-            torch.manual_seed(worker_seed)
-            random.seed(worker_seed)
-
-        return DataLoader(dataset,
-                          batch_size=batch_size,
-                          sampler=sampler,
-                          worker_init_fn=seed_worker,
-                          drop_last=drop_last,
-                          pin_memory=pin_memory,
-                          num_workers=num_workers,
-                          **_kwargs)
-
     def configure(
         self,
         model: nn.Module,
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index c5e310c7e..d30d266c0 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,21 +1,16 @@
-import random
 from typing import Callable, List, Tuple, Union
 
-import numpy as np
-import torch
-import torch.distributed as dist
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 
-from .plugin_base import Plugin
+from .dp_plugin_base import DPPluginBase
 
 __all__ = ['TorchDDPPlugin']
 
@@ -66,7 +61,7 @@ class TorchDDPModel(ModelWrapper):
         return self.module.module
 
 
-class TorchDDPPlugin(Plugin):
+class TorchDDPPlugin(DPPluginBase):
     """
     Plugin for PyTorch DDP.
 
@@ -97,11 +92,7 @@ class TorchDDPPlugin(Plugin):
                  check_reduction: bool = False,
                  gradient_as_bucket_view: bool = False,
                  static_graph: bool = False) -> None:
-
-        assert dist.is_initialized(
-        ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
+        super().__init__()
         self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
                                bucket_cap_mb=bucket_cap_mb,
                                find_unused_parameters=find_unused_parameters,
@@ -124,57 +115,6 @@ class TorchDDPPlugin(Plugin):
     def supported_devices(self) -> List[str]:
         return ['cuda']
 
-    def prepare_train_dataloader(self,
-                                 dataset,
-                                 batch_size,
-                                 shuffle=False,
-                                 seed=1024,
-                                 drop_last=False,
-                                 pin_memory=False,
-                                 num_workers=0,
-                                 **kwargs):
-        r"""
-        Prepare a dataloader for distributed training. The dataloader will be wrapped by
-        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
-
-        Note:
-            1. Evaluation datasets should not be passed to this function.
-
-        Args:
-            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
-            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
-            seed (int, optional): Random worker seed for sampling, defaults to 1024.
-            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
-            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
-                is not divisible by the batch size. If False and the size of dataset is not divisible by
-                the batch size, then the last batch will be smaller, defaults to False.
-            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
-            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
-            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
-                    `DataLoader `_.
-
-        Returns:
-            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
-        """
-        _kwargs = kwargs.copy()
-        sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
-
-        # Deterministic dataloader
-        def seed_worker(worker_id):
-            worker_seed = seed
-            np.random.seed(worker_seed)
-            torch.manual_seed(worker_seed)
-            random.seed(worker_seed)
-
-        return DataLoader(dataset,
-                          batch_size=batch_size,
-                          sampler=sampler,
-                          worker_init_fn=seed_worker,
-                          drop_last=drop_last,
-                          pin_memory=pin_memory,
-                          num_workers=num_workers,
-                          **_kwargs)
-
     def configure(
         self,
         model: nn.Module,
diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py
new file mode 100644
index 000000000..a2b94ba6c
--- /dev/null
+++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -0,0 +1,85 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader, TensorDataset
+
+import colossalai
+from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
+from colossalai.checkpoint_io import CheckpointIO
+from colossalai.interface import OptimizerWrapper
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+class DPPluginWrapper(DPPluginBase):
+    """This is a wrapper class for testing DP plugin initialization and dataloader creation.
+    """
+
+    def configure(
+        self,
+        model: nn.Module,
+        optimizer: Optimizer,
+        criterion: Callable = None,
+        dataloader: DataLoader = None,
+        lr_scheduler: LRScheduler = None,
+    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+        pass
+
+    def control_checkpoint_io(self) -> bool:
+        pass
+
+    def control_device(self) -> bool:
+        pass
+
+    def control_precision(self) -> bool:
+        pass
+
+    def get_checkpoint_io(self) -> CheckpointIO:
+        pass
+
+    def support_no_sync(self) -> bool:
+        pass
+
+    def supported_devices(self) -> List[str]:
+        pass
+
+    def supported_precisions(self) -> List[str]:
+        pass
+
+
+def check_dataloader_sharding():
+    plugin = DPPluginWrapper()
+
+    # create a custom dataset with values 0 to 9
+    dataset = TensorDataset(torch.arange(0, 10))
+    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+
+    # get the first batch of data
+    batch = next(iter(train_dataloader))[0].cuda()
+    is_rank_0 = dist.get_rank() == 0
+
+    if is_rank_0:
+        batch_to_compare = batch.clone()
+    else:
+        batch_to_compare = batch
+    # pass the rank 1 value to rank 0
+    dist.broadcast(batch_to_compare, src=1)
+
+    # compare on rank 0
+    if is_rank_0:
+        assert not torch.equal(batch,
+                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
+
+
+def run_dist(rank, world_size, port):
+    # init dist env
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    check_dataloader_sharding()
+
+
+@rerun_if_address_is_in_use()
+def test_dp_plugin_dataloader():
+    spawn(run_dist, 2)
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 985d7989f..c7b3676fb 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -117,34 +117,9 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
 
 
-def check_dataloader_sharding():
-    plugin = GeminiPlugin()
-
-    # create a custom dasetset with 0 to 10
-    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
-
-    # get the first batch of data
-    batch = next(iter(train_dataloader))[0].cuda()
-    is_rank_0 = dist.get_rank() == 0
-
-    if is_rank_0:
-        batch_to_compare = batch.clone()
-    else:
-        batch_to_compare = batch
-    # pass to the rank 1 value to rank 0
-    dist.broadcast(batch_to_compare, src=1)
-
-    # compare on rank 0
-    if is_rank_0:
-        assert not torch.equal(batch,
-                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
-
-
 def run_dist(rank, world_size, port, early_stop: bool = True):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
-    check_dataloader_sharding()
     check_gemini_plugin(early_stop=early_stop)
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index e24196a14..d84b96f77 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -83,30 +83,6 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
 
 
-def check_dataloader_sharding():
-    plugin = LowLevelZeroPlugin()
-
-    # create a custom dasetset with 0 to 10
-    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
-
-    # get the first batch of data
-    batch = next(iter(train_dataloader))[0].cuda()
-    is_rank_0 = dist.get_rank() == 0
-
-    if is_rank_0:
-        batch_to_compare = batch.clone()
-    else:
-        batch_to_compare = batch
-    # pass to the rank 1 value to rank 0
-    dist.broadcast(batch_to_compare, src=1)
-
-    # compare on rank 0
-    if is_rank_0:
-        assert not torch.equal(batch,
-                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
-
-
 def run_dist(rank, world_size, port, early_stop: bool = True):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 5354eae01..30c4db123 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -44,57 +44,9 @@ def check_torch_ddp_plugin():
     torch.cuda.empty_cache()
 
 
-def check_dataloader_sharding():
-    plugin = TorchDDPPlugin()
-
-    # create a custom dasetset with 0 to 10
-    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
-
-    # get the first batch of data
-    batch = next(iter(train_dataloader))[0].cuda()
-    is_rank_0 = dist.get_rank() == 0
-
-    if is_rank_0:
-        batch_to_compare = batch.clone()
-    else:
-        batch_to_compare = batch
-    # pass to the rank 1 value to rank 0
-    dist.broadcast(batch_to_compare, src=1)
-
-    # compare on rank 0
-    if is_rank_0:
-        assert not torch.equal(batch,
-                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
-
-
-def check_checkpoint_save_and_load():
-    model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet']
-
-    plugin = TorchDDPPlugin()
-    booster = Booster(plugin=plugin)
-
-    model = model_fn()
-    optimizer = SGD(model.parameters(), lr=1e-3)
-    criterion = lambda x: x.mean()
-    data = data_gen_fn()
-
-    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
-
-    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
-
-    output = model(**data)
-    output = output_transform_fn(output)
-    output_key = list(output.keys())[0]
-    loss = criterion(output[output_key])
-
-    booster.backward(loss, optimizer)
-
-
 def run_dist(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
-    check_dataloader_sharding()
     check_torch_ddp_plugin()
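
Usage sketch (editor's addition, not part of the patch): after this refactor every DP plugin inherits prepare_train_dataloader from DPPluginBase, so a training script builds a rank-sharded dataloader the same way whichever plugin it picks. The toy TensorDataset and the launch_from_torch call below are illustrative assumptions, not taken from the diff.

# Minimal sketch: any DP plugin exposes the shared prepare_train_dataloader.
# Assumes the script is started with torchrun so that
# colossalai.launch_from_torch can read the distributed environment variables.
import torch
from torch.utils.data import TensorDataset

import colossalai
from colossalai.booster.plugin import TorchDDPPlugin    # GeminiPlugin / LowLevelZeroPlugin behave the same

colossalai.launch_from_torch(config=dict())

plugin = TorchDDPPlugin()

# toy dataset; prepare_train_dataloader wraps it with a DistributedSampler
# so each rank iterates over a disjoint shard, as the new test above verifies
dataset = TensorDataset(torch.arange(0, 128).float())
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=8, shuffle=True)

for (batch,) in train_dataloader:
    ...    # forward/backward would go here, typically after booster.boost(...)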