From 521017c5c2a12b640e87d1d0192443ba85c7267b Mon Sep 17 00:00:00 2001
From: Justin Richer <jricher@mit.edu>
Date: Wed, 16 Apr 2014 21:39:37 -0400
Subject: [PATCH] updated stats service to have a resettable cache triggered by
 other service events

---
 .../openid/connect/service/StatsService.java  | 11 ++--
 ...faultOAuth2ClientDetailsEntityService.java | 12 ++++-
 .../web/OAuthConfirmationController.java      |  2 +-
 .../impl/DefaultApprovedSiteService.java      | 14 +++--
 .../service/impl/DefaultStatsService.java     | 54 ++++++++++++-------
 .../openid/connect/web/ManagerController.java |  4 +-
 .../mitre/openid/connect/web/StatsAPI.java    |  6 +--
 .../service/impl/TestDefaultStatsService.java | 53 +++++++++++++++---
 8 files changed, 116 insertions(+), 40 deletions(-)

diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java
index 8bf56bbb2..319987b82 100644
--- a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java
+++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/StatsService.java
@@ -35,14 +35,14 @@ public interface StatsService {
 	 * 
 	 * @return
 	 */
-	public Map<String, Integer> calculateSummaryStats();
+	public Map<String, Integer> getSummaryStats();
 
 	/**
 	 * Calculate usage count for all clients
 	 * 
 	 * @return a map of id of client object to number of approvals
 	 */
-	public Map<Long, Integer> calculateByClientId();
+	public Map<Long, Integer> getByClientId();
 
 	/**
 	 * Calculate the usage count for a single client
@@ -50,6 +50,11 @@ public interface StatsService {
 	 * @param id the id of the client to search on
 	 * @return
 	 */
-	public Integer countForClientId(Long id);
+	public Integer getCountForClientId(Long id);
+	
+	/**
+	 * Trigger the stats to be recalculated upon next update.
+	 */
+	public void resetCache();
 
 }
diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java
index 3d538cead..9a7219f73 100644
--- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java
+++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java
@@ -37,6 +37,7 @@ import org.mitre.oauth2.service.SystemScopeService;
 import org.mitre.openid.connect.model.WhitelistedSite;
 import org.mitre.openid.connect.service.ApprovedSiteService;
 import org.mitre.openid.connect.service.BlacklistedSiteService;
+import org.mitre.openid.connect.service.StatsService;
 import org.mitre.openid.connect.service.WhitelistedSiteService;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -76,6 +77,9 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
 
 	@Autowired
 	private SystemScopeService scopeService;
+	
+	@Autowired
+	private StatsService statsService;
 
 	// map of sector URI -> list of redirect URIs
 	private LoadingCache<String, List<String>> sectorRedirects = CacheBuilder.newBuilder()
@@ -136,7 +140,11 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
 		// make sure a client doesn't get any special system scopes
 		client.setScope(scopeService.removeRestrictedScopes(client.getScope()));
 
-		return clientRepository.saveClient(client);
+		ClientDetailsEntity c = clientRepository.saveClient(client);
+		
+		statsService.resetCache();
+		
+		return c;
 	}
 
 	/**
@@ -192,6 +200,8 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
 		// take care of the client itself
 		clientRepository.deleteClient(client);
 
+		statsService.resetCache();
+		
 	}
 
 	/**
diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java b/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java
index 6021332db..a47781d3f 100644
--- a/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java
+++ b/openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java
@@ -171,7 +171,7 @@ public class OAuthConfirmationController {
 		model.put("claims", claimsForScopes);
 
 		// client stats
-		Integer count = statsService.countForClientId(client.getId());
+		Integer count = statsService.getCountForClientId(client.getId());
 		model.put("count", count);
 		
 		
diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java
index 358a1dcc8..c313f9e53 100644
--- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java
+++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java
@@ -26,6 +26,7 @@ import org.mitre.openid.connect.model.ApprovedSite;
 import org.mitre.openid.connect.model.WhitelistedSite;
 import org.mitre.openid.connect.repository.ApprovedSiteRepository;
 import org.mitre.openid.connect.service.ApprovedSiteService;
+import org.mitre.openid.connect.service.StatsService;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -53,6 +54,9 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
 
 	@Autowired
 	private OAuth2TokenRepository tokenRepository;
+	
+	@Autowired
+	private StatsService statsService;
 
 	@Override
 	public Collection<ApprovedSite> getAll() {
@@ -62,7 +66,9 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
 	@Override
 	@Transactional
 	public ApprovedSite save(ApprovedSite approvedSite) {
-		return approvedSiteRepository.save(approvedSite);
+		ApprovedSite a = approvedSiteRepository.save(approvedSite);
+		statsService.resetCache();
+		return a;
 	}
 
 	@Override
@@ -85,6 +91,8 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
 		}
 
 		approvedSiteRepository.remove(approvedSite);
+		
+		statsService.resetCache();
 	}
 
 	@Override
@@ -140,7 +148,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
 		Collection<ApprovedSite> approvedSites = approvedSiteRepository.getByClientId(client.getClientId());
 		if (approvedSites != null) {
 			for (ApprovedSite approvedSite : approvedSites) {
-				approvedSiteRepository.remove(approvedSite);
+				remove(approvedSite);
 			}
 		}
 	}
@@ -153,7 +161,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
 		Collection<ApprovedSite> expiredSites = getExpired();
 		if (expiredSites != null) {
 			for (ApprovedSite expired : expiredSites) {
-				approvedSiteRepository.remove(expired);
+				remove(expired);
 			}
 		}
 	}
diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java
index fc7e0e18a..d9dcc3173 100644
--- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java
+++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultStatsService.java
@@ -53,25 +53,32 @@ public class DefaultStatsService implements StatsService {
 	private ClientDetailsEntityService clientService;
 	
 	// stats cache
-	private Supplier<Map<String, Integer>> summaryCache = Suppliers.memoizeWithExpiration(new Supplier<Map<String, Integer>>() {
-		@Override
-		public Map<String, Integer> get() {
-			return computeSummaryStats();
-		}
+	private Supplier<Map<String, Integer>> summaryCache = createSummaryCache();
+	
+	private Supplier<Map<String, Integer>> createSummaryCache() {
+		return Suppliers.memoizeWithExpiration(new Supplier<Map<String, Integer>>() {
+			@Override
+			public Map<String, Integer> get() {
+				return computeSummaryStats();
+			}
+	
+		}, 10, TimeUnit.MINUTES);
+	}
 
-	}, 10, TimeUnit.MINUTES);
-
-	private Supplier<Map<Long, Integer>> byClientIdCache = Suppliers.memoizeWithExpiration(new Supplier<Map<Long, Integer>>() {
-
-		@Override
-		public Map<Long, Integer> get() {
-			return computeByClientId();
-		}
-		
-	}, 10, TimeUnit.MINUTES);
+	private Supplier<Map<Long, Integer>> byClientIdCache = createByClientIdCache();
+	
+	private Supplier<Map<Long, Integer>> createByClientIdCache() {
+		return Suppliers.memoizeWithExpiration(new Supplier<Map<Long, Integer>>() {
+			@Override
+			public Map<Long, Integer> get() {
+				return computeByClientId();
+			}
+			
+		}, 10, TimeUnit.MINUTES);
+	}
 	
 	@Override
-	public Map<String, Integer> calculateSummaryStats() {
+	public Map<String, Integer> getSummaryStats() {
 		return summaryCache.get();
 	}
 	
@@ -100,7 +107,7 @@ public class DefaultStatsService implements StatsService {
 	 * @see org.mitre.openid.connect.service.StatsService#calculateByClientId()
 	 */
 	@Override
-	public Map<Long, Integer> calculateByClientId() {
+	public Map<Long, Integer> getByClientId() {
 		return byClientIdCache.get();
 	}
 	
@@ -126,9 +133,9 @@ public class DefaultStatsService implements StatsService {
 	 * @see org.mitre.openid.connect.service.StatsService#countForClientId(java.lang.String)
 	 */
 	@Override
-	public Integer countForClientId(Long id) {
+	public Integer getCountForClientId(Long id) {
 
-		Map<Long, Integer> counts = calculateByClientId();
+		Map<Long, Integer> counts = getByClientId();
 		return counts.get(id);
 
 	}
@@ -147,4 +154,13 @@ public class DefaultStatsService implements StatsService {
 		return counts;
 	}
 
+	/**
+	 * Reset both stats caches on a trigger (before the timer runs out). Resets the timers.
+	 */
+	@Override
+	public void resetCache() {
+		summaryCache = createSummaryCache();
+		byClientIdCache = createByClientIdCache();
+	}
+	
 }
diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java
index edc9a07c7..78f1285f5 100644
--- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java
+++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/ManagerController.java
@@ -38,7 +38,7 @@ public class ManagerController {
 	@RequestMapping({"", "home", "index"})
 	public String showHomePage(ModelMap m) {
 
-		Map<String, Integer> summary = statsService.calculateSummaryStats();
+		Map<String, Integer> summary = statsService.getSummaryStats();
 
 		m.put("statsSummary", summary);
 		return "home";
@@ -53,7 +53,7 @@ public class ManagerController {
 	@RequestMapping({"stats", "stats/"})
 	public String showStatsPage(ModelMap m) {
 
-		Map<String, Integer> summary = statsService.calculateSummaryStats();
+		Map<String, Integer> summary = statsService.getSummaryStats();
 
 		m.put("statsSummary", summary);
 		return "stats";
diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java
index 5665d9afd..8de307ea8 100644
--- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java
+++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/StatsAPI.java
@@ -37,7 +37,7 @@ public class StatsAPI {
 	@RequestMapping(value = "summary", produces = "application/json")
 	public String statsSummary(ModelMap m) {
 
-		Map<String, Integer> e = statsService.calculateSummaryStats();
+		Map<String, Integer> e = statsService.getSummaryStats();
 
 		m.put("entity", e);
 
@@ -47,7 +47,7 @@ public class StatsAPI {
 
 	@RequestMapping(value = "byclientid", produces = "application/json")
 	public String statsByClient(ModelMap m) {
-		Map<Long, Integer> e = statsService.calculateByClientId();
+		Map<Long, Integer> e = statsService.getByClientId();
 
 		m.put("entity", e);
 
@@ -56,7 +56,7 @@ public class StatsAPI {
 
 	@RequestMapping(value = "byclientid/{id}", produces = "application/json")
 	public String statsByClientId(@PathVariable("id") Long id, ModelMap m) {
-		Integer e = statsService.countForClientId(id);
+		Integer e = statsService.getCountForClientId(id);
 
 		m.put("entity", e);
 
diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java
index d75f772e0..8e0338f53 100644
--- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java
+++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestDefaultStatsService.java
@@ -59,6 +59,8 @@ public class TestDefaultStatsService {
 	private ApprovedSite ap2 = Mockito.mock(ApprovedSite.class);
 	private ApprovedSite ap3 = Mockito.mock(ApprovedSite.class);
 	private ApprovedSite ap4 = Mockito.mock(ApprovedSite.class);
+	private ApprovedSite ap5 = Mockito.mock(ApprovedSite.class);
+	private ApprovedSite ap6 = Mockito.mock(ApprovedSite.class);
 
 	private ClientDetailsEntity client1 = Mockito.mock(ClientDetailsEntity.class);
 	private ClientDetailsEntity client2 = Mockito.mock(ClientDetailsEntity.class);
@@ -95,6 +97,12 @@ public class TestDefaultStatsService {
 		Mockito.when(ap4.getUserId()).thenReturn(userId2);
 		Mockito.when(ap4.getClientId()).thenReturn(clientId3);
 
+		Mockito.when(ap5.getUserId()).thenReturn(userId2);
+		Mockito.when(ap5.getClientId()).thenReturn(clientId1);
+		
+		Mockito.when(ap6.getUserId()).thenReturn(userId1);
+		Mockito.when(ap6.getClientId()).thenReturn(clientId4);
+
 		Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4));
 
 		Mockito.when(client1.getId()).thenReturn(1L);
@@ -114,7 +122,7 @@ public class TestDefaultStatsService {
 
 		Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet<ApprovedSite>());
 
-		Map<String, Integer> stats = service.calculateSummaryStats();
+		Map<String, Integer> stats = service.getSummaryStats();
 
 		assertThat(stats.get("approvalCount"), is(0));
 		assertThat(stats.get("userCount"), is(0));
@@ -123,7 +131,7 @@ public class TestDefaultStatsService {
 
 	@Test
 	public void calculateSummaryStats() {
-		Map<String, Integer> stats = service.calculateSummaryStats();
+		Map<String, Integer> stats = service.getSummaryStats();
 
 		assertThat(stats.get("approvalCount"), is(4));
 		assertThat(stats.get("userCount"), is(2));
@@ -135,7 +143,7 @@ public class TestDefaultStatsService {
 
 		Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet<ApprovedSite>());
 
-		Map<Long, Integer> stats = service.calculateByClientId();
+		Map<Long, Integer> stats = service.getByClientId();
 
 		assertThat(stats.get(1L), is(0));
 		assertThat(stats.get(2L), is(0));
@@ -146,7 +154,7 @@ public class TestDefaultStatsService {
 	@Test
 	public void calculateByClientId() {
 
-		Map<Long, Integer> stats = service.calculateByClientId();
+		Map<Long, Integer> stats = service.getByClientId();
 
 		assertThat(stats.get(1L), is(2));
 		assertThat(stats.get(2L), is(1));
@@ -157,9 +165,38 @@ public class TestDefaultStatsService {
 	@Test
 	public void countForClientId() {
 
-		assertThat(service.countForClientId(1L), is(2));
-		assertThat(service.countForClientId(2L), is(1));
-		assertThat(service.countForClientId(3L), is(1));
-		assertThat(service.countForClientId(4L), is(0));
+		assertThat(service.getCountForClientId(1L), is(2));
+		assertThat(service.getCountForClientId(2L), is(1));
+		assertThat(service.getCountForClientId(3L), is(1));
+		assertThat(service.getCountForClientId(4L), is(0));
+	}
+	
+	@Test
+	public void cacheAndReset() {
+		
+		Map<String, Integer> stats = service.getSummaryStats();
+
+		assertThat(stats.get("approvalCount"), is(4));
+		assertThat(stats.get("userCount"), is(2));
+		assertThat(stats.get("clientCount"), is(3));
+
+		Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4, ap5, ap6));
+		
+		Map<String, Integer> stats2 = service.getSummaryStats();
+		
+		// cache should remain the same due to memoized functions
+		assertThat(stats2.get("approvalCount"), is(4));
+		assertThat(stats2.get("userCount"), is(2));
+		assertThat(stats2.get("clientCount"), is(3));		
+		
+		// reset the cache and make sure the count goes up
+		service.resetCache();
+		
+		Map<String, Integer> stats3 = service.getSummaryStats();
+		
+		assertThat(stats3.get("approvalCount"), is(6));
+		assertThat(stats3.get("userCount"), is(2));
+		assertThat(stats3.get("clientCount"), is(4));		
+		
 	}
 }