@@ -81,9 +81,10 @@ class StrategyGenerator(ABC):
                     for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                         dim_size = len(logical_shape)
                         dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
-                        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                                     entire_shape=logical_shape,
-                                                     dim_partition_dict=dim_partition_dict_element)
+                        sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
+                                                             entire_shape=logical_shape,
+                                                             dim_partition_dict=dim_partition_dict_element)
+                        sharding_spec.append(sharding_spec_element)
                 else:
                     assert isinstance(
                         op_data.data, torch.Tensor
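The hunk above replaces an assignment that overwrote sharding_spec on every pass through the loop with an append, so a tuple operand ends up with one ShardingSpec per logical shape instead of only the last one. A minimal, self-contained sketch of that difference (toy strings stand in for ShardingSpec objects; not part of the patch):

logical_shapes = [(4, 8), (8, 16), (16, 32)]

# Old behaviour: the variable is overwritten, only the spec of the last shape survives.
spec_overwritten = None
for shape in logical_shapes:
    spec_overwritten = f"spec{shape}"

# New behaviour: one spec is appended per logical shape.
specs_appended = []
for shape in logical_shapes:
    specs_appended.append(f"spec{shape}")

assert spec_overwritten == "spec(16, 32)"
assert len(specs_appended) == len(logical_shapes)
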
@@ -193,18 +194,40 @@ class StrategyGenerator(ABC):
         Args:
             strategy (ShardingStrategy): the ShardingStrategy generated.
             key (str): the name of the operation data defined by the generator.
-
         """
         op_data = self.op_data[key]
-        sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
 
-        if len(sharded_shape) == 0:
-            num_elements = 1
+        def _compute_size_in_bytes_helper(sharding_spec, meta_data):
+            sharded_shape = sharding_spec.get_sharded_shape_per_device()
+            if len(sharded_shape) == 0:
+                num_elements = 1
+            else:
+                num_elements = reduce(operator.mul, sharded_shape)
+            dtype = getattr(meta_data, 'dtype')
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            return num_elements * size_per_elem_bytes
+
+        if isinstance(op_data.data, tuple):
+            assert isinstance(strategy.sharding_specs[op_data], list), \
+                'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
+            total_bytes = 0
+            for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
+                meta_data = op_data.data[index]
+                if isinstance(meta_data, torch.Tensor):
+                    element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
+                else:
+                    # if meta_data is not a tensor, we count the memory as 0
+                    element_bytes = 0
+                total_bytes += element_bytes
+
         else:
-            num_elements = reduce(operator.mul, sharded_shape)
-        dtype = self.op_data[key].data.dtype
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return num_elements * size_per_elem_bytes
+            if isinstance(op_data.data, torch.Tensor):
+                total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
+            else:
+                # if op_data.data is not a tensor, we count the memory as 0
+                total_bytes = 0
+
+        return total_bytes
 
     def generate(self) -> List[ShardingStrategy]:
         """
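For reference, the per-device size computed by the new _compute_size_in_bytes_helper reduces to the product of the sharded dimensions times the element size of the dtype. A standalone sketch of that arithmetic using plain PyTorch (size_in_bytes is a hypothetical name, not part of the patch):

import operator
from functools import reduce

import torch


def size_in_bytes(sharded_shape, dtype):
    # A zero-dim (scalar) shard still holds one element.
    num_elements = 1 if len(sharded_shape) == 0 else reduce(operator.mul, sharded_shape)
    # element_size() returns the number of bytes per element for the given dtype.
    return num_elements * torch.tensor([], dtype=dtype).element_size()


# A [1024, 512] fp16 shard occupies 1024 * 512 * 2 bytes.
assert size_in_bytes([1024, 512], torch.float16) == 1024 * 512 * 2
assert size_in_bytes([], torch.float32) == 4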