From 385306ffd9b361b62fed3f9261fb25a1d5943a7f Mon Sep 17 00:00:00 2001 From: johnniang <1340692778@qq.com> Date: Wed, 20 Feb 2019 00:51:31 +0800 Subject: [PATCH] Customize base repository implementation --- src/main/java/cc/ryanc/halo/Application.java | 3 + .../repository/base/BaseRepositoryImpl.java | 91 +++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java diff --git a/src/main/java/cc/ryanc/halo/Application.java b/src/main/java/cc/ryanc/halo/Application.java index 0f65eb888..691f25df1 100755 --- a/src/main/java/cc/ryanc/halo/Application.java +++ b/src/main/java/cc/ryanc/halo/Application.java @@ -1,11 +1,13 @@ package cc.ryanc.halo; +import cc.ryanc.halo.repository.base.BaseRepositoryImpl; import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCaching; import org.springframework.context.ApplicationContext; import org.springframework.data.jpa.repository.config.EnableJpaAuditing; +import org.springframework.data.jpa.repository.config.EnableJpaRepositories; /** *
@@ -19,6 +21,7 @@ import org.springframework.data.jpa.repository.config.EnableJpaAuditing;
 @SpringBootApplication
 @EnableCaching
 @EnableJpaAuditing
+@EnableJpaRepositories(basePackages = "cc.ryanc.halo.repository", repositoryBaseClass = BaseRepositoryImpl.class)
 public class Application {
     public static void main(String[] args) {
         ApplicationContext context = SpringApplication.run(Application.class, args);
diff --git a/src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java b/src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java
new file mode 100644
index 000000000..0dd44b3a5
--- /dev/null
+++ b/src/main/java/cc/ryanc/halo/repository/base/BaseRepositoryImpl.java
@@ -0,0 +1,91 @@
+package cc.ryanc.halo.repository.base;
+
+import cc.ryanc.halo.logging.Logger;
+import org.springframework.data.domain.Sort;
+import org.springframework.data.jpa.domain.Specification;
+import org.springframework.data.jpa.repository.support.JpaEntityInformation;
+import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+
+import javax.persistence.EntityManager;
+import javax.persistence.TypedQuery;
+import javax.persistence.criteria.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Implementation of base repository.
+ *
+ * @param  domain type
+ * @param      id type
+ */
+public class BaseRepositoryImpl extends SimpleJpaRepository implements BaseRepository {
+
+    private final Logger log = Logger.getLogger(getClass());
+
+    private final JpaEntityInformation entityInformation;
+
+    private final EntityManager entityManager;
+
+    public BaseRepositoryImpl(JpaEntityInformation entityInformation, EntityManager entityManager) {
+        super(entityInformation, entityManager);
+        this.entityInformation = entityInformation;
+        this.entityManager = entityManager;
+    }
+
+    @Override
+    public List findAllByIdIn(Iterable ids, Sort sort) {
+        Assert.notNull(ids, "The given Iterable of Id's must not be null!");
+
+        log.debug("Customized findAllById method was invoked");
+
+        if (!ids.iterator().hasNext()) {
+            return Collections.emptyList();
+        }
+        if (!this.entityInformation.hasCompositeId()) {
+            ByIdsSpecification specification = new ByIdsSpecification<>(this.entityInformation);
+            TypedQuery query = super.getQuery(specification, sort);
+            return query.setParameter(specification.parameter, ids).getResultList();
+        } else {
+            List results = new ArrayList<>();
+
+            ids.forEach(id -> super.findById(id).ifPresent(results::add));
+
+            return results;
+        }
+    }
+
+    @Override
+    public long deleteByIdIn(Iterable ids) {
+
+        log.debug("Customized deleteByIdIn method was invoked");
+        // Find all domains
+        List domains = findAllById(ids);
+
+        // Delete in batch
+        deleteInBatch(domains);
+
+        // Return the size of domain deleted
+        return domains.size();
+    }
+
+    private static final class ByIdsSpecification implements Specification {
+        private static final long serialVersionUID = 1L;
+        private final JpaEntityInformation entityInformation;
+        @Nullable
+        ParameterExpression parameter;
+
+        ByIdsSpecification(JpaEntityInformation entityInformation) {
+            this.entityInformation = entityInformation;
+        }
+
+        public Predicate toPredicate(Root root, CriteriaQuery query, CriteriaBuilder cb) {
+            Path path = root.get(this.entityInformation.getIdAttribute());
+            this.parameter = cb.parameter(Iterable.class);
+            return path.in(this.parameter);
+        }
+    }
+}