From 385306ffd9b361b62fed3f9261fb25a1d5943a7f Mon Sep 17 00:00:00 2001 From: johnniang <1340692778@qq.com> Date: Wed, 20 Feb 2019 00:51:31 +0800 Subject: [PATCH] Customize base repository implementation --- src/main/java/cc/ryanc/halo/Application.java | 3 + .../repository/base/BaseRepositoryImpl.java | 91 +++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java diff --git a/src/main/java/cc/ryanc/halo/Application.java b/src/main/java/cc/ryanc/halo/Application.java index 0f65eb888..691f25df1 100755 --- a/src/main/java/cc/ryanc/halo/Application.java +++ b/src/main/java/cc/ryanc/halo/Application.java @@ -1,11 +1,13 @@ package cc.ryanc.halo; +import cc.ryanc.halo.repository.base.BaseRepositoryImpl; import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCaching; import org.springframework.context.ApplicationContext; import org.springframework.data.jpa.repository.config.EnableJpaAuditing; +import org.springframework.data.jpa.repository.config.EnableJpaRepositories; /** *
@@ -19,6 +21,7 @@ import org.springframework.data.jpa.repository.config.EnableJpaAuditing; @SpringBootApplication @EnableCaching @EnableJpaAuditing +@EnableJpaRepositories(basePackages = "cc.ryanc.halo.repository", repositoryBaseClass = BaseRepositoryImpl.class) public class Application { public static void main(String[] args) { ApplicationContext context = SpringApplication.run(Application.class, args); diff --git a/src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java b/src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java new file mode 100644 index 000000000..0dd44b3a5 --- /dev/null +++ b/src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java @@ -0,0 +1,91 @@ +package cc.ryanc.halo.repository.base; + +import cc.ryanc.halo.logging.Logger; +import org.springframework.data.domain.Sort; +import org.springframework.data.jpa.domain.Specification; +import org.springframework.data.jpa.repository.support.JpaEntityInformation; +import org.springframework.data.jpa.repository.support.SimpleJpaRepository; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import javax.persistence.EntityManager; +import javax.persistence.TypedQuery; +import javax.persistence.criteria.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +/** + * Implementation of base repository. + * + * @paramdomain type + * @param id type + */ +public class BaseRepositoryImpl extends SimpleJpaRepository implements BaseRepository { + + private final Logger log = Logger.getLogger(getClass()); + + private final JpaEntityInformation entityInformation; + + private final EntityManager entityManager; + + public BaseRepositoryImpl(JpaEntityInformation entityInformation, EntityManager entityManager) { + super(entityInformation, entityManager); + this.entityInformation = entityInformation; + this.entityManager = entityManager; + } + + @Override + public List findAllByIdIn(Iterable ids, Sort sort) { + Assert.notNull(ids, "The given Iterable of Id's must not be null!"); + + log.debug("Customized findAllById method was invoked"); + + if (!ids.iterator().hasNext()) { + return Collections.emptyList(); + } + if (!this.entityInformation.hasCompositeId()) { + ByIdsSpecification specification = new ByIdsSpecification<>(this.entityInformation); + TypedQuery query = super.getQuery(specification, sort); + return query.setParameter(specification.parameter, ids).getResultList(); + } else { + List results = new ArrayList<>(); + + ids.forEach(id -> super.findById(id).ifPresent(results::add)); + + return results; + } + } + + @Override + public long deleteByIdIn(Iterable ids) { + + log.debug("Customized deleteByIdIn method was invoked"); + // Find all domains + List domains = findAllById(ids); + + // Delete in batch + deleteInBatch(domains); + + // Return the size of domain deleted + return domains.size(); + } + + private static final class ByIdsSpecification implements Specification { + private static final long serialVersionUID = 1L; + private final JpaEntityInformation entityInformation; + @Nullable + ParameterExpression parameter; + + ByIdsSpecification(JpaEntityInformation entityInformation) { + this.entityInformation = entityInformation; + } + + public Predicate toPredicate(Root root, CriteriaQuery> query, CriteriaBuilder cb) { + Path> path = root.get(this.entityInformation.getIdAttribute()); + this.parameter = cb.parameter(Iterable.class); + return path.in(this.parameter); + } + } +}