RateLimitFilter.java
package se.jobtechdev.personaldatagateway.api.ratelimit;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpFilter;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import java.io.IOException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.stereotype.Component;
/**
 * Servlet filter that applies per-client rate limiting.
 *
 * <p>Each request is mapped to a rate-limit profile name and a bucket key by {@link
 * #extractProfileAndKeyFromRequest(HttpServletRequest)}. The matching profile obtained from {@link
 * RateLimitProfiles} consumes quota for that key, and the resulting quota is reported through the
 * {@code RateLimit-Policy} and {@code RateLimit} response headers. When the quota is exceeded, the
 * request is answered with {@code 429 Too Many Requests} and the filter chain is not invoked.
 */
@Component
@Order(10)
public class RateLimitFilter extends HttpFilter {
  /** Profile name used when no authenticated profile can be derived from the request. */
  private static final String DEFAULT_PROFILE = "default";

  /** Session attribute from which the Spring Security context is read. */
  private static final String SECURITY_CONTEXT_ATTRIBUTE = "SPRING_SECURITY_CONTEXT";

  private final transient RateLimitProfiles rateLimitProfiles;

  /**
   * Creates the filter.
   *
   * @param rateLimitProfiles source of the rate-limit profile to apply for a given profile name
   */
  @Autowired
  public RateLimitFilter(RateLimitProfiles rateLimitProfiles) {
    this.rateLimitProfiles = rateLimitProfiles;
  }

  /**
   * The rate-limit profile to apply to a request and the key identifying the caller's bucket.
   *
   * @param profile name of the rate-limit profile
   * @param key identifier of the caller within that profile (client id or remote address)
   */
  public record ProfileAndKey(String profile, String key) {}

  /**
   * Determines which rate-limit profile and key apply to {@code request}.
   *
   * <p>If the request's existing session holds a security context whose authentication has a
   * {@code String} principal, the first granted authority is used as the profile name and the
   * principal as the key. In every other case, the {@code "default"} profile keyed by the remote
   * address is returned. This includes: no session, no security context, no authentication, a
   * non-{@code String} principal, no authorities, or an authority with no string form.
   *
   * <p>Never creates a session ({@code getSession(false)}).
   *
   * @param request the incoming request
   * @return the profile and key to rate-limit against; never {@code null}
   */
  protected static ProfileAndKey extractProfileAndKeyFromRequest(HttpServletRequest request) {
    final var defaultProfileAndKey = new ProfileAndKey(DEFAULT_PROFILE, request.getRemoteAddr());

    final HttpSession httpSession = request.getSession(false);
    if (httpSession == null) {
      return defaultProfileAndKey;
    }

    // A session can exist without a stored security context (or with one of another type), so
    // fall back to the default profile instead of failing with a NullPointerException or
    // ClassCastException.
    if (!(httpSession.getAttribute(SECURITY_CONTEXT_ATTRIBUTE) instanceof SecurityContextImpl sci)) {
      return defaultProfileAndKey;
    }

    final var authentication = sci.getAuthentication();
    // The client id is expected to be the String principal. Any other principal type cannot be
    // used as a key, so such requests are limited under the default profile.
    if (authentication == null || !(authentication.getPrincipal() instanceof String clientId)) {
      return defaultProfileAndKey;
    }

    // Only the first authority is considered. NOTE(review): if a client can hold several
    // authorities, iteration order of the collection decides the profile — confirm intended.
    return authentication.getAuthorities().stream()
        .findFirst()
        .map(GrantedAuthority::getAuthority)
        .map(profile -> new ProfileAndKey(profile, clientId))
        .orElse(defaultProfileAndKey);
  }

  /**
   * Consumes rate-limit quota for the request and either rejects it or passes it on.
   *
   * <p>The {@code RateLimit-Policy} and {@code RateLimit} headers are always set. When the quota is
   * exceeded, a plain-text {@code 429 Too Many Requests} response is written and {@code
   * filterChain} is not invoked; otherwise the request proceeds down the chain.
   *
   * @param request the incoming request
   * @param response the outgoing response
   * @param filterChain the remaining filter chain
   * @throws ServletException if a downstream filter or servlet fails
   * @throws IOException if writing the response or a downstream component fails
   */
  @Override
  public void doFilter(
      HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
      throws ServletException, IOException {
    final var profileAndKey = extractProfileAndKeyFromRequest(request);
    // NOTE(review): behavior for a profile name unknown to RateLimitProfiles (e.g. an authority
    // with no configured profile) depends on getRateLimitProfile — confirm it cannot return null.
    final var rateLimitProfile = rateLimitProfiles.getRateLimitProfile(profileAndKey.profile());
    final var quota = rateLimitProfile.consume(profileAndKey.key());

    response.setHeader("RateLimit-Policy", quota.rateLimitPolicy());
    response.setHeader("RateLimit", quota.rateLimit());

    if (quota.exceeded()) {
      response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
      response.setContentType("text/plain");
      response.getWriter().append(HttpStatus.TOO_MANY_REQUESTS.getReasonPhrase());
    } else {
      filterChain.doFilter(request, response);
    }
  }
}