updated stats service to have a resettable cache triggered by other service events

Conflicts:
	openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ClientDetailsEntityService.java
	openid-connect-server/src/main/java/org/mitre/oauth2/web/OAuthConfirmationController.java
	openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java
pull/661/merge
Justin Richer 2014-04-16 21:39:37 -04:00
parent 7e4c153ba2
commit 3b27fc61da
7 changed files with 118 additions and 39 deletions

View File

@ -35,14 +35,14 @@ public interface StatsService {
* *
* @return * @return
*/ */
public Map<String, Integer> calculateSummaryStats(); public Map<String, Integer> getSummaryStats();
/** /**
* Calculate usage count for all clients * Calculate usage count for all clients
* *
* @return a map of id of client object to number of approvals * @return a map of id of client object to number of approvals
*/ */
public Map<Long, Integer> calculateByClientId(); public Map<Long, Integer> getByClientId();
/** /**
* Calculate the usage count for a single client * Calculate the usage count for a single client
@ -50,6 +50,11 @@ public interface StatsService {
* @param id the id of the client to search on * @param id the id of the client to search on
* @return * @return
*/ */
public Integer countForClientId(Long id); public Integer getCountForClientId(Long id);
/**
* Trigger the stats to be recalculated upon next update.
*/
public void resetCache();
} }

View File

@ -30,6 +30,7 @@ import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.model.WhitelistedSite;
import org.mitre.openid.connect.service.ApprovedSiteService; import org.mitre.openid.connect.service.ApprovedSiteService;
import org.mitre.openid.connect.service.BlacklistedSiteService; import org.mitre.openid.connect.service.BlacklistedSiteService;
import org.mitre.openid.connect.service.StatsService;
import org.mitre.openid.connect.service.WhitelistedSiteService; import org.mitre.openid.connect.service.WhitelistedSiteService;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException; import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
@ -56,6 +57,8 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
@Autowired @Autowired
private BlacklistedSiteService blacklistedSiteService; private BlacklistedSiteService blacklistedSiteService;
@Autowired
private StatsService statsService;
@Override @Override
public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) { public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
@ -87,7 +90,11 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
// timestamp this to right now // timestamp this to right now
client.setCreatedAt(new Date()); client.setCreatedAt(new Date());
return clientRepository.saveClient(client); ClientDetailsEntity c = clientRepository.saveClient(client);
statsService.resetCache();
return c;
} }
/** /**
@ -143,6 +150,8 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
// take care of the client itself // take care of the client itself
clientRepository.deleteClient(client); clientRepository.deleteClient(client);
statsService.resetCache();
} }
/** /**

View File

@ -20,10 +20,12 @@ import java.util.Collection;
import java.util.Date; import java.util.Date;
import java.util.Set; import java.util.Set;
import org.mitre.oauth2.repository.OAuth2TokenRepository;
import org.mitre.openid.connect.model.ApprovedSite; import org.mitre.openid.connect.model.ApprovedSite;
import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.model.WhitelistedSite;
import org.mitre.openid.connect.repository.ApprovedSiteRepository; import org.mitre.openid.connect.repository.ApprovedSiteRepository;
import org.mitre.openid.connect.service.ApprovedSiteService; import org.mitre.openid.connect.service.ApprovedSiteService;
import org.mitre.openid.connect.service.StatsService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@ -49,6 +51,12 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
@Autowired @Autowired
private ApprovedSiteRepository approvedSiteRepository; private ApprovedSiteRepository approvedSiteRepository;
@Autowired
private OAuth2TokenRepository tokenRepository;
@Autowired
private StatsService statsService;
@Override @Override
public Collection<ApprovedSite> getAll() { public Collection<ApprovedSite> getAll() {
return approvedSiteRepository.getAll(); return approvedSiteRepository.getAll();
@ -57,7 +65,9 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
@Override @Override
@Transactional @Transactional
public ApprovedSite save(ApprovedSite approvedSite) { public ApprovedSite save(ApprovedSite approvedSite) {
return approvedSiteRepository.save(approvedSite); ApprovedSite a = approvedSiteRepository.save(approvedSite);
statsService.resetCache();
return a;
} }
@Override @Override
@ -69,6 +79,8 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
@Transactional @Transactional
public void remove(ApprovedSite approvedSite) { public void remove(ApprovedSite approvedSite) {
approvedSiteRepository.remove(approvedSite); approvedSiteRepository.remove(approvedSite);
statsService.resetCache();
} }
@Override @Override
@ -124,7 +136,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
Collection<ApprovedSite> approvedSites = approvedSiteRepository.getByClientId(client.getClientId()); Collection<ApprovedSite> approvedSites = approvedSiteRepository.getByClientId(client.getClientId());
if (approvedSites != null) { if (approvedSites != null) {
for (ApprovedSite approvedSite : approvedSites) { for (ApprovedSite approvedSite : approvedSites) {
approvedSiteRepository.remove(approvedSite); remove(approvedSite);
} }
} }
} }
@ -137,7 +149,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
Collection<ApprovedSite> expiredSites = getExpired(); Collection<ApprovedSite> expiredSites = getExpired();
if (expiredSites != null) { if (expiredSites != null) {
for (ApprovedSite expired : expiredSites) { for (ApprovedSite expired : expiredSites) {
approvedSiteRepository.remove(expired); remove(expired);
} }
} }
} }

View File

@ -53,25 +53,32 @@ public class DefaultStatsService implements StatsService {
private ClientDetailsEntityService clientService; private ClientDetailsEntityService clientService;
// stats cache // stats cache
private Supplier<Map<String, Integer>> summaryCache = Suppliers.memoizeWithExpiration(new Supplier<Map<String, Integer>>() { private Supplier<Map<String, Integer>> summaryCache = createSummaryCache();
@Override
public Map<String, Integer> get() { private Supplier<Map<String, Integer>> createSummaryCache() {
return computeSummaryStats(); return Suppliers.memoizeWithExpiration(new Supplier<Map<String, Integer>>() {
} @Override
public Map<String, Integer> get() {
return computeSummaryStats();
}
}, 10, TimeUnit.MINUTES);
}
}, 10, TimeUnit.MINUTES); private Supplier<Map<Long, Integer>> byClientIdCache = createByClientIdCache();
private Supplier<Map<Long, Integer>> byClientIdCache = Suppliers.memoizeWithExpiration(new Supplier<Map<Long, Integer>>() { private Supplier<Map<Long, Integer>> createByClientIdCache() {
return Suppliers.memoizeWithExpiration(new Supplier<Map<Long, Integer>>() {
@Override @Override
public Map<Long, Integer> get() { public Map<Long, Integer> get() {
return computeByClientId(); return computeByClientId();
} }
}, 10, TimeUnit.MINUTES); }, 10, TimeUnit.MINUTES);
}
@Override @Override
public Map<String, Integer> calculateSummaryStats() { public Map<String, Integer> getSummaryStats() {
return summaryCache.get(); return summaryCache.get();
} }
@ -100,7 +107,7 @@ public class DefaultStatsService implements StatsService {
* @see org.mitre.openid.connect.service.StatsService#calculateByClientId() * @see org.mitre.openid.connect.service.StatsService#calculateByClientId()
*/ */
@Override @Override
public Map<Long, Integer> calculateByClientId() { public Map<Long, Integer> getByClientId() {
return byClientIdCache.get(); return byClientIdCache.get();
} }
@ -126,9 +133,9 @@ public class DefaultStatsService implements StatsService {
* @see org.mitre.openid.connect.service.StatsService#countForClientId(java.lang.String) * @see org.mitre.openid.connect.service.StatsService#countForClientId(java.lang.String)
*/ */
@Override @Override
public Integer countForClientId(Long id) { public Integer getCountForClientId(Long id) {
Map<Long, Integer> counts = calculateByClientId(); Map<Long, Integer> counts = getByClientId();
return counts.get(id); return counts.get(id);
} }
@ -147,4 +154,13 @@ public class DefaultStatsService implements StatsService {
return counts; return counts;
} }
/**
* Reset both stats caches on a trigger (before the timer runs out). Resets the timers.
*/
@Override
public void resetCache() {
summaryCache = createSummaryCache();
byClientIdCache = createByClientIdCache();
}
} }

View File

@ -38,7 +38,7 @@ public class ManagerController {
@RequestMapping({"", "home", "index"}) @RequestMapping({"", "home", "index"})
public String showHomePage(ModelMap m) { public String showHomePage(ModelMap m) {
Map<String, Integer> summary = statsService.calculateSummaryStats(); Map<String, Integer> summary = statsService.getSummaryStats();
m.put("statsSummary", summary); m.put("statsSummary", summary);
return "home"; return "home";
@ -53,7 +53,7 @@ public class ManagerController {
@RequestMapping({"stats", "stats/"}) @RequestMapping({"stats", "stats/"})
public String showStatsPage(ModelMap m) { public String showStatsPage(ModelMap m) {
Map<String, Integer> summary = statsService.calculateSummaryStats(); Map<String, Integer> summary = statsService.getSummaryStats();
m.put("statsSummary", summary); m.put("statsSummary", summary);
return "stats"; return "stats";

View File

@ -37,7 +37,7 @@ public class StatsAPI {
@RequestMapping(value = "summary", produces = "application/json") @RequestMapping(value = "summary", produces = "application/json")
public String statsSummary(ModelMap m) { public String statsSummary(ModelMap m) {
Map<String, Integer> e = statsService.calculateSummaryStats(); Map<String, Integer> e = statsService.getSummaryStats();
m.put("entity", e); m.put("entity", e);
@ -47,7 +47,7 @@ public class StatsAPI {
@RequestMapping(value = "byclientid", produces = "application/json") @RequestMapping(value = "byclientid", produces = "application/json")
public String statsByClient(ModelMap m) { public String statsByClient(ModelMap m) {
Map<Long, Integer> e = statsService.calculateByClientId(); Map<Long, Integer> e = statsService.getByClientId();
m.put("entity", e); m.put("entity", e);
@ -56,7 +56,7 @@ public class StatsAPI {
@RequestMapping(value = "byclientid/{id}", produces = "application/json") @RequestMapping(value = "byclientid/{id}", produces = "application/json")
public String statsByClientId(@PathVariable("id") Long id, ModelMap m) { public String statsByClientId(@PathVariable("id") Long id, ModelMap m) {
Integer e = statsService.countForClientId(id); Integer e = statsService.getCountForClientId(id);
m.put("entity", e); m.put("entity", e);

View File

@ -59,6 +59,8 @@ public class TestDefaultStatsService {
private ApprovedSite ap2 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap2 = Mockito.mock(ApprovedSite.class);
private ApprovedSite ap3 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap3 = Mockito.mock(ApprovedSite.class);
private ApprovedSite ap4 = Mockito.mock(ApprovedSite.class); private ApprovedSite ap4 = Mockito.mock(ApprovedSite.class);
private ApprovedSite ap5 = Mockito.mock(ApprovedSite.class);
private ApprovedSite ap6 = Mockito.mock(ApprovedSite.class);
private ClientDetailsEntity client1 = Mockito.mock(ClientDetailsEntity.class); private ClientDetailsEntity client1 = Mockito.mock(ClientDetailsEntity.class);
private ClientDetailsEntity client2 = Mockito.mock(ClientDetailsEntity.class); private ClientDetailsEntity client2 = Mockito.mock(ClientDetailsEntity.class);
@ -95,6 +97,12 @@ public class TestDefaultStatsService {
Mockito.when(ap4.getUserId()).thenReturn(userId2); Mockito.when(ap4.getUserId()).thenReturn(userId2);
Mockito.when(ap4.getClientId()).thenReturn(clientId3); Mockito.when(ap4.getClientId()).thenReturn(clientId3);
Mockito.when(ap5.getUserId()).thenReturn(userId2);
Mockito.when(ap5.getClientId()).thenReturn(clientId1);
Mockito.when(ap6.getUserId()).thenReturn(userId1);
Mockito.when(ap6.getClientId()).thenReturn(clientId4);
Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4)); Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4));
Mockito.when(client1.getId()).thenReturn(1L); Mockito.when(client1.getId()).thenReturn(1L);
@ -114,7 +122,7 @@ public class TestDefaultStatsService {
Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet<ApprovedSite>()); Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet<ApprovedSite>());
Map<String, Integer> stats = service.calculateSummaryStats(); Map<String, Integer> stats = service.getSummaryStats();
assertThat(stats.get("approvalCount"), is(0)); assertThat(stats.get("approvalCount"), is(0));
assertThat(stats.get("userCount"), is(0)); assertThat(stats.get("userCount"), is(0));
@ -123,7 +131,7 @@ public class TestDefaultStatsService {
@Test @Test
public void calculateSummaryStats() { public void calculateSummaryStats() {
Map<String, Integer> stats = service.calculateSummaryStats(); Map<String, Integer> stats = service.getSummaryStats();
assertThat(stats.get("approvalCount"), is(4)); assertThat(stats.get("approvalCount"), is(4));
assertThat(stats.get("userCount"), is(2)); assertThat(stats.get("userCount"), is(2));
@ -135,7 +143,7 @@ public class TestDefaultStatsService {
Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet<ApprovedSite>()); Mockito.when(approvedSiteService.getAll()).thenReturn(new HashSet<ApprovedSite>());
Map<Long, Integer> stats = service.calculateByClientId(); Map<Long, Integer> stats = service.getByClientId();
assertThat(stats.get(1L), is(0)); assertThat(stats.get(1L), is(0));
assertThat(stats.get(2L), is(0)); assertThat(stats.get(2L), is(0));
@ -146,7 +154,7 @@ public class TestDefaultStatsService {
@Test @Test
public void calculateByClientId() { public void calculateByClientId() {
Map<Long, Integer> stats = service.calculateByClientId(); Map<Long, Integer> stats = service.getByClientId();
assertThat(stats.get(1L), is(2)); assertThat(stats.get(1L), is(2));
assertThat(stats.get(2L), is(1)); assertThat(stats.get(2L), is(1));
@ -157,9 +165,38 @@ public class TestDefaultStatsService {
@Test @Test
public void countForClientId() { public void countForClientId() {
assertThat(service.countForClientId(1L), is(2)); assertThat(service.getCountForClientId(1L), is(2));
assertThat(service.countForClientId(2L), is(1)); assertThat(service.getCountForClientId(2L), is(1));
assertThat(service.countForClientId(3L), is(1)); assertThat(service.getCountForClientId(3L), is(1));
assertThat(service.countForClientId(4L), is(0)); assertThat(service.getCountForClientId(4L), is(0));
}
@Test
public void cacheAndReset() {
Map<String, Integer> stats = service.getSummaryStats();
assertThat(stats.get("approvalCount"), is(4));
assertThat(stats.get("userCount"), is(2));
assertThat(stats.get("clientCount"), is(3));
Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4, ap5, ap6));
Map<String, Integer> stats2 = service.getSummaryStats();
// cache should remain the same due to memoized functions
assertThat(stats2.get("approvalCount"), is(4));
assertThat(stats2.get("userCount"), is(2));
assertThat(stats2.get("clientCount"), is(3));
// reset the cache and make sure the count goes up
service.resetCache();
Map<String, Integer> stats3 = service.getSummaryStats();
assertThat(stats3.get("approvalCount"), is(6));
assertThat(stats3.get("userCount"), is(2));
assertThat(stats3.get("clientCount"), is(4));
} }
} }