pull/1438/head
Jiarui Fang 2 years ago committed by GitHub
parent 039b7ed3bc
commit c9427a323f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ from .cache_mgr import CachedParamMgr
from torch.nn.parameter import Parameter
from .._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
@ -57,13 +57,15 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
if _weight is None:
self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size),
dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]),
compute_attr=ComputePattern.TP1D)
self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
self.embedding_dim_per_partition,
device='cpu',
dtype=dtype),
requires_grad=True,
spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
spec=colo_tensor_spec)
self.init_parameters()
else:
assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"

Loading…
Cancel
Save