[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)

* sharded optimizer checkpoint for gemini plugin

* modify test to reduce testing time

* update doc

* fix bug when keep_gathered is true under GeminiPlugin
Baizhou Zhang 2023-07-21 14:39:01 +08:00 committed by GitHub
parent fc5cef2c79
commit c6f6005990
12 changed files with 289 additions and 84 deletions


@ -1,3 +1,4 @@
import gc
import logging
import os
import warnings
@ -12,11 +13,19 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase
@ -37,7 +46,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, this must be called on all processes.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
@ -54,7 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, this must be called on all processes.
As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
@ -76,7 +85,8 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
@ -86,28 +96,24 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
if not self.coordinator.is_master():
continue
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
index_file.append_meta_data("total_size", total_size)
# Save shards of model weights.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors)
# only save the index file on the master rank
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_sharded_model(self,
model: GeminiDDP,
@ -115,7 +121,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
strict: bool = False,
use_safetensors: bool = False):
"""
load shard model, load model from multiple files
Load a sharded model from multiple checkpoint files.
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
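For orientation, here is a minimal usage sketch of the sharded model path through the booster API. It is a sketch only: the plugin arguments, the HybridAdam choice, the shard size and the index-file name are assumptions, and the script must be launched with torchrun so that every rank takes part in the collective state gathering.

import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})
plugin = GeminiPlugin(placement_policy='cpu')    # placement policy chosen arbitrarily for the sketch
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Sharded save/load of model weights; every rank calls these, only the master rank writes files.
booster.save_model(model, 'model_ckpt', shard=True, size_per_shard=32)
booster.load_model(model, 'model_ckpt/pytorch_model.bin.index.json')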
@ -125,16 +131,93 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()
assert isinstance(optimizer, ZeroOptimizer)
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
param_groups = optimizer.get_param_groups_for_saving()
torch.save(param_groups, group_file_path)
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=is_master,
use_safetensors=False)
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
"""
Load a sharded optimizer from a checkpoint folder, given its index file.
Each process only loads the optimizer states of the parameters it manages.
"""
# TODO(Baizhou): To be implemented.
pass
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()
assert isinstance(optimizer, ZeroOptimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
# Load param_groups.
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory.')
saved_param_groups = torch.load(param_group_path)
optimizer.load_param_groups(saved_param_groups)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()
optimizer.optimizer_loading_epilogue()
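And the matching restore from the sketch above: each rank reads the shard files but keeps only the states of the parameters it manages, then runs the loading epilogue. The index-file name below is illustrative of what get_optimizer_base_filenames produces, not a guaranteed name:

booster.load_optimizer(optimizer, 'optim_ckpt/pytorch_optim.bin.index.json')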
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
class GeminiModel(ModelWrapper):


@ -5,6 +5,7 @@ from functools import reduce
from pathlib import Path
from typing import Iterator, Optional, OrderedDict, Tuple
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
@ -16,7 +17,6 @@ from .utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
@ -25,6 +25,7 @@ from .utils import (
load_states_into_optimizer,
save_param_groups,
save_state_dict,
save_state_dict_shards,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
@ -122,15 +123,13 @@ class GeneralCheckpointIO(CheckpointIO):
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@ -172,18 +171,17 @@ class GeneralCheckpointIO(CheckpointIO):
# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)
# Save shards of model weights.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)


@ -1,4 +1,5 @@
# coding=utf-8
import os
import re
from collections import abc as container_abcs
from collections import defaultdict
@ -103,6 +104,43 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False) -> int:
'''
Save a sharded state dict only on the master rank. This helper can be used for both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, where each shard is a (state dict, shard size) pair.
checkpoint (str): The path of the checkpoint directory as a string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of the shard filenames.
is_master (bool): Whether the current rank is the master rank.
use_safetensors (bool): Whether to use safetensors to save the checkpoint.
Returns:
int: the total size of the saved shards
'''
total_size = 0
for idx, shard_pair in enumerate(sharded_state_dict):
if not is_master:
continue
shard, current_size = shard_pair
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
return total_size
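To see this helper in isolation, a toy single-process driver might look as follows; the shard generator and the sizes are fabricated for illustration, whereas real callers feed it the generators returned by model.state_dict_shard or optimizer.state_shard:

import os
from collections import OrderedDict

import torch

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import get_model_base_filenames, save_state_dict_shards

checkpoint_dir = './toy_ckpt'
os.makedirs(checkpoint_dir, exist_ok=True)

def toy_shards():
    # Each item is a (state dict shard, shard size) pair; the sizes are illustrative.
    yield OrderedDict(weight=torch.zeros(4, 4)), 64
    yield OrderedDict(bias=torch.zeros(4)), 16

weights_name, save_index_file = get_model_base_filenames(None, False)
index_file = CheckpointIndexFile(checkpoint_dir)
total_size = save_state_dict_shards(sharded_state_dict=toy_shards(),
                                    checkpoint=checkpoint_dir,
                                    index_file=index_file,
                                    base_filename=weights_name,
                                    is_master=True,    # single process, so it acts as master
                                    use_safetensors=False)
index_file.append_meta_data('total_size', total_size)
index_file.write_index_file(save_index_file)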
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a


@ -3,7 +3,7 @@ import copy
import gc
import math
import warnings
from typing import Any, Dict, Set, Tuple
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
import torch
import torch.distributed as dist
@ -11,8 +11,10 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@ -360,10 +362,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
if chunk.keep_gathered:
shard_offset = 0
else:
shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0
return chunk_offset, shard_offset, shard_size
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
@ -427,7 +431,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
dtype=torch.float32,
requires_grad=False).cpu()
else:
collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu()
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
return collected_states
# Check whether the param with given id is managed by current process.
@ -536,6 +541,31 @@ class ZeroOptimizer(ColossalaiOptimizer):
target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size])
next_state_offset += shard_size
def get_param_groups_for_saving(self) -> list:
'''
Return the param_groups in PyTorch format for saving to a checkpoint.
'''
param_groups = copy.deepcopy(self.param_groups_backup)
# To be compatible with PyTorch checkpointing,
# store the extra hyperparameters used by the PyTorch Adam optimizer.
torch_special_hyperparameters = {
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': False
}
for group in param_groups:
for k, v in torch_special_hyperparameters.items():
if k not in group:
group[k] = v
return param_groups
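To make the output concrete, a param group written to the param_group_file then looks like a regular torch.optim.Adam group plus the padded compatibility keys; all values below are purely illustrative:

example_param_group = {
    'lr': 1e-3,
    'betas': (0.9, 0.999),
    'eps': 1e-8,
    'weight_decay': 0.0,
    # Keys padded in by get_param_groups_for_saving for torch.optim.Adam compatibility:
    'amsgrad': False,
    'maximize': False,
    'foreach': None,
    'capturable': False,
    'differentiable': False,
    'fused': False,
    # Parameters are referenced by integer ids, as in a standard PyTorch optimizer state_dict.
    'params': [0, 1, 2],
}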
def state_dict(self, only_rank_0: bool = True) -> dict:
"""
Args:
@ -555,21 +585,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
so it should be called only when memory resources are abundant.
"""
state_dict = {}
state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
torch_special_hyperparameters = {
'amsgrad': False,
'maximize': False,
'foreach': None,
'capturable': False,
'differentiable': False,
'fused': False
}
for group in state_dict['param_groups']:
for k, v in torch_special_hyperparameters.items():
if k not in group:
group[k] = v
state_dict['param_groups'] = self.get_param_groups_for_saving()
# Collect optimizer states.
state_dict['state'] = dict()
@ -634,8 +650,24 @@ class ZeroOptimizer(ColossalaiOptimizer):
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
def load_param_states(self, param_states: dict):
"""Loads param states from a state_dict. The param_states can be complete or sharded.
During loading, filter out the part of states not considered by current process.
Args:
param_states (dict): A mapping from param_id to its states.
"""
for param_id, states in param_states.items():
if param_id in self.id_to_fake_params:
self.load_single_param_states(param_id, states)
def optimizer_loading_epilogue(self):
# Epilogue when loading a state_dict into the PyTorch optimizer.
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault('differentiable', False)
def load_state_dict(self, state_dict: dict):
"""Loads optimizer state from whole optimizer state_dict.
"""Loads optimizer state from complete optimizer state_dict.
During loading, states of parameters not managed by the current process are filtered out.
Args:
@ -643,17 +675,71 @@ class ZeroOptimizer(ColossalaiOptimizer):
from a call to :meth:`state_dict`.
"""
assert 'param_groups' in state_dict
assert 'state' in state_dict
self.load_param_groups(state_dict['param_groups'])
self.load_param_states(state_dict['state'])
self.optimizer_loading_epilogue()
state = state_dict['state']
for param_id, param_states in state.items():
if param_id in self.id_to_fake_params:
self.load_single_param_states(param_id, param_states)
# Epilogue for pytorch optimizer.
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault('differentiable', False)
def state_shard(self,
prefix: str = '',
max_shard_size: int = 1024,
only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing shards of optimizer states one by one.
The max size of each dictionary shard is specified by ``max_shard_size``.
Args:
prefix (str, optional): the prefix for states. Default to ''.
max_shard_size (int, optional): max size of a state dict shard (in MB). Defaults to 1024.
only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
only on rank 0, default to True.
Yields:
Iterator[OrderedDict]: A generator of state dict shards of optimizer states.
"""
current_block = {}
current_block_size = 0
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
ret_block = None
ret_block_size = 0
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not a Tensor instance,
# e.g. an SGD optimizer with momentum set to 0 can have None as a state,
# skip the size calculation to avoid errors.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
if ret_block is not None:
yield ret_block, ret_block_size
yield current_block, current_block_size
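A hedged sketch of consuming this generator, assuming `optimizer` is a ZeroOptimizer (e.g. a GeminiAdamOptimizer) constructed on every rank; like collect_states, the iteration itself must run on all ranks because of the barriers and gather operations inside:

# Each yielded item is a (shard, shard_size) pair, where `shard` maps param_id to that
# parameter's collected states (e.g. 'step', 'exp_avg', 'exp_avg_sq' for Adam).
for shard, shard_size in optimizer.state_shard(prefix='', max_shard_size=32, only_rank_0=True):
    print(f'shard with {len(shard)} params, estimated size: {shard_size}')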
class GeminiAdamOptimizer(ZeroOptimizer):


@ -21,10 +21,13 @@ Plugin is an important component that manages parallel configuration (eg: The ge
**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines.
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of PyTorch. It implements data parallelism at the module level and can run across multiple machines.
**_LowLevelZeroPlugin:_** This plugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data-parallel workers/GPUs. Stage 2: shards optimizer states and gradients across data-parallel workers/GPUs.
**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of PyTorch and can be used to train models with zero-dp.
### API of booster
{{ autodoc:colossalai.booster.Booster }}


@ -21,8 +21,6 @@ Model must be boosted by `colossalai.booster.Booster` before loading. It will de
## Optimizer Checkpoint
> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet.
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
Optimizer must be boosted by `colossalai.booster.Booster` before saving.


@ -51,8 +51,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
> ⚠ This plugin can currently only load an optimizer checkpoint that it saved itself with the same number of processes. This will be fixed in the future.
### Torch DDP Plugin
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).


@ -24,10 +24,13 @@ Booster plugins are important components that manage parallel configurations (e.g. the gemini plugin encapsulates
**_GeminiPlugin:_** The GeminiPlugin wraps the gemini acceleration solution, i.e. ZeRO with chunk-based memory management.
**_TorchDDPPlugin:_** The TorchDDPPlugin wraps the DDP acceleration solution; it implements data parallelism at the module level and can run across multiple machines.
**_TorchDDPPlugin:_** The TorchDDPPlugin wraps PyTorch's DDP acceleration solution; it implements data parallelism at the module level and can run across multiple machines.
**_LowLevelZeroPlugin:_** The LowLevelZeroPlugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data-parallel processes/GPUs. Stage 2: shards optimizer states and gradients across data-parallel processes/GPUs.
**_TorchFSDPPlugin:_** The TorchFSDPPlugin wraps PyTorch's FSDP acceleration solution and can be used for zero-redundancy data-parallel (ZeroDP) training.
### Booster API
<!--TODO: update autodoc -->


@ -21,7 +21,6 @@
## Optimizer Checkpoint
> ⚠ Saving optimizer checkpoints in a sharded way is not supported yet.
{{ autodoc:colossalai.booster.Booster.save_optimizer }}


@ -51,7 +51,6 @@ Zero-2 does not support local gradient accumulation. If you insist on using it, although you can accumulate
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
> ⚠ This plugin can currently only load optimizer checkpoints that it saved itself with the same number of processes. This will be fixed in the future.
### Torch DDP Plugin


@ -52,7 +52,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('shard', [False])
@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
@ -117,7 +117,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)


@ -19,7 +19,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize('shard', [False])
@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
@ -83,7 +83,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
@clear_cache_before_run()
@parameterize('shard', [False])
@parameterize('shard', [False, True])
@parameterize('model_name', ['transformers_gpt'])
def exam_gemini_load_from_torch(shard: bool, model_name: str):
@ -165,7 +165,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)