2023-05-22 07:02:17 +00:00
|
|
|
from dataclasses import dataclass
|
|
|
|
|
2023-05-24 08:01:26 +00:00
|
|
|
__all__ = ['ShardConfig']
|
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
|
|
|
|
@dataclass
class ShardConfig:
    """Configuration for sharding a huggingface model.

    Args:
        rank (int): The rank of the local process. Required; it has no
            default value.
        world_size (int): The world size of the distributed process.
            Defaults to 2.
        gather_output (bool): Whether to gather the output of the last layer
            of the model. Defaults to True.
    """

    # Required field: it must stay first so the defaulted fields can follow it.
    rank: int
    world_size: int = 2
    gather_output: bool = True
|