[booster] add low level zero plugin (#3594)

* [booster] add low level zero plugin

* [booster] fix gemini plugin test

* [booster] fix precision

* [booster] add low level zero plugin test

* [test] fix booster plugin test oom

* [test] fix booster plugin test oom

* [test] fix googlenet and inception output trans

* [test] fix diffuser clip vision model

* [test] fix torchaudio_wav2vec2_base

* [test] fix low level zero plugin test
Hongxin Liu 2023-04-26 14:37:25 +08:00 committed by GitHub
parent b9a8dff7e5
commit 4b3240cb59
9 changed files with 476 additions and 81 deletions


@ -1,5 +1,6 @@
from .gemini_plugin import GeminiPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin']
__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']


@ -0,0 +1,259 @@
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPCheckpointIO
__all__ = ['LowLevelZeroPlugin']
def _convert_to_fp16(x):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.half()
return x
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
class LowLevelZeroModel(ModelWrapper):
def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.convert_inputs = (precision == 'fp16')
module = zero_model_wrapper(module, zero_stage=stage)
if precision == 'fp16':
module = module.half()
module = module.to(get_current_device())
self.module = module
def forward(self, *args, **kwargs):
if self.convert_inputs:
args = tree_map(_convert_to_fp16, args)
kwargs = tree_map(_convert_to_fp16, kwargs)
return super().forward(*args, **kwargs)
class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(self,
module: nn.Module,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
class LowLevelZeroPlugin(Plugin):
"""
    Plugin for low-level ZeRO (stage 1 or 2).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import LowLevelZeroPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = LowLevelZeroPlugin()
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
        stage (int, optional): ZeRO stage. Defaults to 1.
        precision (str, optional): Training precision. Supports 'fp16' and 'fp32'. Defaults to 'fp16'.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
            `clip_grad_norm` yourself when using ZeRO DDP; the ZeRO optimizer handles gradient clipping.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
reduce_bucket_size_in_m (int, optional): grad reduce bucket size in M. Defaults to 12.
communication_dtype (torch.dtype, optional): communication dtype. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True.
        cpu_offload (bool, optional): whether to offload gradients, master weights and optimizer states to CPU. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
"""
def __init__(
self,
stage: int = 1,
precision: str = 'fp16',
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
reduce_bucket_size_in_m: int = 12,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
verbose: bool = False,
) -> None:
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.stage = stage
self.precision = precision
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload)
self.optim_kwargs = dict(initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)
self.verbose = verbose
def support_no_sync(self) -> bool:
return False
def control_precision(self) -> bool:
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp32']
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
        Prepare a dataloader for distributed training. The dataset is wrapped with a
        `torch.utils.data.distributed.DistributedSampler` and returned as a `torch.utils.data.DataLoader`,
        so that each rank iterates over its own shard of the data.
Note:
1. Evaluation datasets should not be passed to this function.
Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
            batch_size (int): The number of samples per batch on each rank.
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
            seed (int, optional): Random worker seed for sampling. Defaults to 1024.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller, defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.stage, self.precision)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()
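
For orientation, here is a minimal end-to-end sketch of how the new plugin is driven, assembled from the docstring example and the tests later in this commit. The toy Linear model, the random TensorDataset and the launch call are illustrative assumptions rather than part of the diff; the unpacking order of booster.boost follows the configure() signature above.

# Minimal training sketch (assumes a distributed environment started via colossalai,
# e.g. `colossalai run --nproc_per_node 2 train.py`). The Linear model and random
# TensorDataset are placeholders for illustration only.
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = nn.Linear(32, 4)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda logits: logits.mean()    # same trivial criterion as the plugin tests

dataset = TensorDataset(torch.randn(64, 32))
plugin = LowLevelZeroPlugin(stage=2, precision='fp16', initial_scale=2**5)
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=8, shuffle=True)

booster = Booster(plugin=plugin)
# configure() returns (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

for (inputs,) in train_dataloader:
    inputs = inputs.cuda()
    loss = criterion(model(inputs))     # inputs are cast to fp16 by LowLevelZeroModel.forward
    booster.backward(loss, optimizer)   # delegates to the wrapped ZeRO optimizer's backward
    optimizer.step()
    optimizer.zero_grad()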


@ -55,6 +55,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
# 5. support layer drop
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
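
The added "# 5. support layer drop" TODO is worth a note: with LayerDrop, whole layers are skipped at random during training, so their parameters receive no gradient in that step, which this optimizer does not yet handle. This is also the likely reason the torchaudio_wav2vec2_base registration below pins encoder_layer_drop=0.0. The snippet below is a hypothetical illustration of the pattern, not code from this repository.

# Hypothetical LayerDrop illustration (not ColossalAI code): with probability p a whole
# layer is skipped in training mode, so its parameters get no gradient for that step.
import random
import torch
import torch.nn as nn

class LayerDropStack(nn.Module):
    def __init__(self, num_layers: int, dim: int, p: float = 0.5):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.training and random.random() < self.p:
                continue    # skipped layer: its weight/bias see no gradient this iteration
            x = torch.relu(layer(x))
        return x

model = LayerDropStack(num_layers=4, dim=8)
model(torch.randn(2, 8)).sum().backward()
# Some layers may now have p.grad is None -- the situation the TODO refers to.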


@ -18,6 +18,7 @@ data_vae_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32))
data_unet_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32), timestep=3)
identity_output = lambda x: x
clip_vision_model_output = lambda x: dict(pooler_output=x[1])
def data_clip_model():
@ -65,7 +66,7 @@ model_zoo.register(name='diffusers_clip_text_model',
model_zoo.register(name='diffusers_clip_vision_model',
model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()),
data_gen_fn=data_clip_vision,
output_transform_fn=identity_output)
output_transform_fn=clip_vision_model_output)
model_zoo.register(name='diffusers_unet2d_model',
model_fn=diffusers.UNet2DModel,
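
The fix above swaps the CLIP vision model's transform from identity_output to clip_vision_model_output. For context, the model-zoo tests in this commit all follow the same contract: the raw model output is mapped to a dict of tensors, the first value is reduced to a scalar by a trivial mean() criterion, and the loss is backpropagated. A standalone sketch of that contract, with made-up tensor shapes standing in for a real CLIPVisionModel output:

# Standalone illustration of the output_transform_fn contract used by the plugin tests.
# The tuple stands in for a CLIPVisionModel output of (last_hidden_state, pooler_output);
# the shapes are invented for the example.
import torch

identity_output = lambda x: x
clip_vision_model_output = lambda x: dict(pooler_output=x[1])   # as registered in the diff above

raw_output = (torch.randn(2, 50, 768), torch.randn(2, 768, requires_grad=True))
transformed = clip_vision_model_output(raw_output)

output_key = list(transformed.keys())[0]    # the test loop picks the first key...
loss = transformed[output_key].mean()       # ...and reduces it to a scalar loss
loss.backward()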


@ -1,3 +1,5 @@
from functools import partial
import torch
import torchaudio.models as tm
@ -101,13 +103,11 @@ def tacotron_data_gen_fn():
mel_specgram_lengths=mel_specgram_lengths)
model_zoo.register(
name='torchaudio_tacotron',
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
data_gen_fn=tacotron_data_gen_fn,
output_transform_fn=lambda outputs: dict(
spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='torchaudio_tacotron',
model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
data_gen_fn=tacotron_data_gen_fn,
output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)),
model_attribute=ModelAttribute(has_control_flow=True))
def wav2vec_data_gen_fn():
@ -118,7 +118,7 @@ def wav2vec_data_gen_fn():
model_zoo.register(name='torchaudio_wav2vec2_base',
model_fn=tm.wav2vec2_base,
model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0),
data_gen_fn=wav2vec_data_gen_fn,
output_transform_fn=transformer_output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))


@ -36,12 +36,12 @@ def swin_s():
# special output transform fn
google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs
) else dict(output=x)
google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs
) else dict(output=x)
swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs
) else dict(output=x)
inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs
) else dict(output=x)
model_zoo.register(name='torchvision_alexnet',
model_fn=tm.alexnet,
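
The googlenet and inception transforms now sum every field of the output namedtuple instead of keeping only .logits. A plausible reading, not stated in the diff: in training mode these models also return auxiliary-classifier logits, and summing all heads keeps the auxiliary branches in the autograd graph so every parameter receives a gradient, which the ZeRO-based plugins expect. The sketch below mimics GoogLeNetOutputs with a namedtuple so it runs without torchvision:

# Mimic torchvision's GoogLeNetOutputs with a namedtuple so the snippet is self-contained.
from collections import namedtuple
import torch

GoogLeNetOutputsLike = namedtuple('GoogLeNetOutputsLike', ['logits', 'aux_logits2', 'aux_logits1'])

heads = GoogLeNetOutputsLike(*(torch.randn(2, 10, requires_grad=True) for _ in range(3)))

transformed = dict(output=sum(heads))   # mirrors the new transform: element-wise sum of all heads
loss = transformed['output'].mean()     # the tests' trivial criterion
loss.backward()                         # gradients now reach every head, not just `logits`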


@ -1,4 +1,5 @@
from contextlib import nullcontext
from typing import Optional
import torch
import torch.distributed as dist
@ -10,11 +11,53 @@ from colossalai.fx import is_compatible_with_meta
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.model.experimental import LazyInitContext
from colossalai.zero import ColoInitContext
from tests.kit.model_zoo import model_zoo
@parameterize('init_method', ['lazy', 'none', 'colo'])
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
try:
if init_method == 'colo':
ctx = ColoInitContext()
elif init_method == 'lazy':
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
for n, p in model.named_parameters():
assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
except Exception as e:
return repr(e)
# TODO(ver217): CI does not support lazy now
# @parameterize('init_method', ['lazy', 'none', 'colo'])
@parameterize('init_method', ['none'])
def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
"""check gemini plugin over model zoo
@ -25,7 +68,6 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
if not is_support_meta and init_method == 'lazy':
return
from colossalai.utils.model.experimental import LazyInitContext
passed_models = []
failed_info = {} # (model_name, error) pair
@ -58,48 +100,16 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
]:
continue
try:
if init_method == 'colo':
ctx = ColoInitContext()
elif init_method == 'lazy':
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
for n, p in model.named_parameters():
assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter'
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
passed_models.append(name)
del booster, plugin, model, optimizer, criterion, data, output, loss
except Exception as e:
failed_info[name] = e
if early_stop:
raise e
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break
if dist.get_rank() == 0:
print(f'Init method: {init_method}')
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
@ -140,7 +150,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
spawn(run_dist, 4, early_stop=early_stop)
if __name__ == '__main__':
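
The refactor above moves the per-model boost/forward/backward into a standalone run_fn that returns the error string (or None), and the caller empties the CUDA cache between models. Once run_fn returns, its locals (model, optimizer, activations) go out of scope, so torch.cuda.empty_cache() can actually release the memory, which is what fixes the OOM mentioned in the commit message. Below is a simplified sketch of that pattern with hypothetical names (it assumes a CUDA device):

# Simplified sketch of the OOM-avoidance pattern (hypothetical toy model; requires CUDA).
from typing import Callable, Dict, Optional
import torch
import torch.nn as nn

def run_one(model_fn: Callable[[], nn.Module]) -> Optional[str]:
    """Run one model's smoke test; return repr(error) on failure, None on success."""
    try:
        model = model_fn().cuda()
        loss = model(torch.randn(4, 8, device='cuda')).mean()
        loss.backward()
    except Exception as e:    # mirror the tests' broad exception catch
        return repr(e)
    return None               # all tensors created here go out of scope on return

failed_info: Dict[str, str] = {}
for name, model_fn in {'toy_linear': lambda: nn.Linear(8, 8)}.items():
    err = run_one(model_fn)
    torch.cuda.empty_cache()  # with run_one's locals gone, this really frees memory
    if err is not None:
        failed_info[name] = err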


@ -0,0 +1,122 @@
from typing import Optional
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
# These models are not compatible with AMP
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
# These models have no parameters
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
# These models will get stuck
_STUCK_MODELS = [
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
'transformers_bert_for_pretraining', 'transformers_gpt_double_heads'
]
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
except Exception as e:
return repr(e)
@parameterize('stage', [2])
def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
"""check low level zero plugin over model zoo
Args:
        stage (int): stage of the low level zero plugin.
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
passed_models = []
failed_info = {} # (model_name, error) pair
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
# FIXME(ver217): fix these models
if name in ignore_models:
skipped_models.append(name)
continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break
if dist.get_rank() == 0:
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
print(f'Skipped models({len(skipped_models)}): {skipped_models}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def check_dataloader_sharding():
plugin = LowLevelZeroPlugin()
    # create a custom dataset with values 0 to 9
dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
# get the first batch of data
batch = next(iter(train_dataloader))[0].cuda()
is_rank_0 = dist.get_rank() == 0
if is_rank_0:
batch_to_compare = batch.clone()
else:
batch_to_compare = batch
    # broadcast the value from rank 1 to rank 0
dist.broadcast(batch_to_compare, src=1)
# compare on rank 0
if is_rank_0:
assert not torch.equal(batch,
batch_to_compare), 'Same number was found across ranks but expected it to be different'
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_low_level_zero_plugin(early_stop=early_stop)
@rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
if __name__ == '__main__':
test_low_level_zero_plugin(early_stop=False)


@ -11,36 +11,37 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def check_torch_ddp_plugin():
def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
assert isinstance(model.module, DDP)
assert isinstance(optimizer, OptimizerWrapper)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
def check_torch_ddp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
if name == 'dlrm_interactionarch':
continue
model = model_fn()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
assert isinstance(model.module, DDP)
assert isinstance(optimizer, OptimizerWrapper)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()
def check_dataloader_sharding():