2023-05-22 07:02:17 +00:00
|
|
|
from typing import Optional, Type

import torch

from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from .shard_config import ShardConfig
|
2023-05-22 07:02:17 +00:00
|
|
|
|
2023-06-12 05:56:09 +00:00
|
|
|
# Weight dimension sliced for each parallel policy layer class; Slicer.slice_weight_bias
# flips it (1 - dim) when called with ``reversed=True``.
dim_mapping = {
    Col_Layer: 0,
    Row_Layer: 1,
}
|
2023-05-22 07:02:17 +00:00
|
|
|
|
2023-05-24 02:26:46 +00:00
|
|
|
|
2023-05-22 07:02:17 +00:00
|
|
|
class Slicer():
    r"""
    Cut weights and biases into the shard owned by the current rank.

    Args:
        shardconfig (:class:`ShardConfig`): Sharding configuration; its ``world_size`` and
            ``rank`` attributes decide how many shards are cut and which one is kept.
    """

    def __init__(
        self,
        shardconfig: "ShardConfig",  # TODO
    ) -> None:
        self.shardconfig = shardconfig

    def slice_weight_bias(
        self,
        weight: torch.Tensor,
        bias: torch.Tensor,
        policy_layer_cls: "Type[Layer]",
        n_cast: Optional[int] = None,
        reversed: bool = False,
    ):
        r"""
        Slice the weight and bias according to the policy layer class.

        ``Layer``     -> return weight and bias unchanged
        ``Col_Layer`` -> slice the weight along dim 0 (dim 1 if ``reversed``) and the bias along dim 0
        ``Row_Layer`` -> slice the weight along dim 1 (dim 0 if ``reversed``); the bias is not sliced

        Args:
            weight (:class:`torch.Tensor`): The weight of the layer
            bias (:class:`torch.Tensor`): The bias of the layer; ``None`` is passed through
            policy_layer_cls: ``Layer``, ``Col_Layer`` or ``Row_Layer``; selects how to slice
            n_cast (int, optional): If given, the sliced dimension is treated as ``n_cast``
                consecutive sub-blocks that are sharded separately (see :meth:`_slice_along_dim`)
            reversed (bool): Slice the weight along the other dimension and transpose the
                sliced weight before returning it

        Returns:
            Tuple ``(weight, bias)`` holding this rank's shards.

        Raises:
            NotImplementedError: If ``policy_layer_cls`` is not a supported class.
        """
        if policy_layer_cls == Layer:
            return weight, bias

        # Validate before indexing dim_mapping: otherwise an unsupported class raises a bare
        # KeyError on the lookup and this NotImplementedError could never be reached.
        if policy_layer_cls not in dim_mapping:
            raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")

        dim = dim_mapping[policy_layer_cls]
        if reversed:
            dim = 1 - dim

        weight = self.slice_tensor(weight, dim, False, n_cast)
        if policy_layer_cls == Col_Layer:
            # Col_Layer shards the bias as well; Row_Layer keeps the full bias on every rank.
            bias = self.slice_tensor(bias, 0, True, n_cast)

        if reversed:
            weight = weight.transpose(0, 1).contiguous()
        return weight, bias

    def slice_tensor(
        self,
        tensor_in: torch.Tensor,
        dim: int,
        is_bias: bool,
        n_cast: Optional[int] = None,
    ) -> Optional[torch.Tensor]:
        r"""
        Slice a tensor according to the config.

        Args:
            tensor_in (:class:`torch.Tensor`): The tensor to slice; ``None`` is returned as-is
            dim (int): The dimension to slice; ignored for biases, which are always sliced along dim 0
            is_bias (bool): Whether the tensor is a bias
            n_cast (int, optional): Number of sub-blocks sharded separately (see :meth:`_slice_along_dim`)

        Returns:
            This rank's shard of ``tensor_in``, or ``None`` if ``tensor_in`` is ``None``.
        """
        if tensor_in is None:
            return None
        if is_bias:
            return self.slice_1d(tensor_in, n_cast)
        return self.slice_2d(tensor_in, dim, n_cast)

    def slice_2d(
        self,
        tensor: torch.Tensor,
        dim: int,
        n_cast: Optional[int] = None,
    ) -> torch.Tensor:
        r"""
        Slice a 2D tensor along dim 0 (rows) or dim 1 (columns).

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice
            dim (int): The dimension to slice; must be 0 or 1
            n_cast (int, optional): Number of sub-blocks sharded separately (see :meth:`_slice_along_dim`)

        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
        assert dim in [0, 1], f"Only support slicing a 2D tensor along dim 0 or 1, but got dim={dim}"
        if dim == 0:
            return self.slice_row(tensor, n_cast)
        elif dim == 1:
            return self.slice_col(tensor, n_cast)

    def slice_1d(
        self,
        tensor: torch.Tensor,
        n_cast: Optional[int] = None,
    ) -> torch.Tensor:
        r"""
        Slice a 1D tensor along dim 0.

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice
            n_cast (int, optional): Number of sub-blocks sharded separately (see :meth:`_slice_along_dim`)

        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
        return self._slice_along_dim(tensor, 0, n_cast)

    def slice_col(
        self,
        tensor: torch.Tensor,
        n_cast: Optional[int] = None,
    ) -> torch.Tensor:
        r"""
        Slice the tensor by column (dim 1).

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice
            n_cast (int, optional): Number of sub-blocks sharded separately (see :meth:`_slice_along_dim`)

        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
        return self._slice_along_dim(tensor, 1, n_cast)

    def slice_row(
        self,
        tensor: torch.Tensor,
        n_cast: Optional[int] = None,
    ) -> torch.Tensor:
        r"""
        Slice the tensor by row (dim 0).

        Args:
            tensor (:class:`torch.Tensor`): The tensor to slice
            n_cast (int, optional): Number of sub-blocks sharded separately (see :meth:`_slice_along_dim`)

        Returns:
            :class:`torch.Tensor`: The sliced tensor
        """
        return self._slice_along_dim(tensor, 0, n_cast)

    def _slice_along_dim(
        self,
        tensor: torch.Tensor,
        dim: int,
        n_cast: Optional[int] = None,
    ) -> torch.Tensor:
        r"""
        Return the current rank's shard of ``tensor`` along ``dim`` (shared by slice_1d/row/col).

        Without ``n_cast`` the tensor is cut into ``world_size`` chunks and chunk ``rank`` is kept.
        With ``n_cast`` it is cut into ``world_size * n_cast`` chunks and the rank keeps chunks
        ``rank, rank + world_size, rank + 2 * world_size, ...`` concatenated along ``dim``; when the
        size is divisible by ``world_size * n_cast`` this is the rank's piece of each of the
        ``n_cast`` consecutive sub-blocks.
        """
        world_size = self.shardconfig.world_size
        rank = self.shardconfig.rank
        # NOTE(review): torch.chunk returns uneven (or fewer) chunks when the size along ``dim`` is
        # not divisible by the chunk count; callers presumably guarantee divisibility — confirm.
        if n_cast is None:
            return tensor.chunk(world_size, dim=dim)[rank].contiguous()
        chunks = tensor.chunk(world_size * n_cast, dim=dim)
        return torch.cat(chunks[rank::world_size], dim=dim).contiguous()
|