// RateLimitFilter.java

package se.jobtechdev.personaldatagateway.api.ratelimit;

import jakarta.servlet.*;
import jakarta.servlet.http.HttpFilter;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import java.io.IOException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.Order;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.stereotype.Component;

/**
 * Servlet filter that consumes one unit of rate-limit quota per request and rejects the request
 * with {@code 429 Too Many Requests} once the quota is exceeded.
 *
 * <p>The rate-limit profile and the key the quota is tracked under are derived from the
 * authentication stored in the HTTP session, see {@link
 * #extractProfileAndKeyFromRequest(HttpServletRequest)}. Requests for which no usable
 * authentication is found are limited under the {@code "default"} profile, keyed by the remote
 * address.
 */
@Component
@Order(10)
public class RateLimitFilter extends HttpFilter {
  /** Session attribute under which the security context is looked up. */
  private static final String SECURITY_CONTEXT_SESSION_ATTRIBUTE = "SPRING_SECURITY_CONTEXT";

  /** Profile name used when no authenticated client can be determined. */
  private static final String DEFAULT_PROFILE = "default";

  // NOTE(review): presumably transient because the servlet filter base class is Serializable
  // while RateLimitProfiles is not meant to be serialized — confirm.
  private final transient RateLimitProfiles rateLimitProfiles;

  /**
   * Creates the filter.
   *
   * @param rateLimitProfiles lookup used to resolve a profile name to its rate-limit profile
   */
  @Autowired
  public RateLimitFilter(RateLimitProfiles rateLimitProfiles) {
    this.rateLimitProfiles = rateLimitProfiles;
  }

  /**
   * Pair of the rate-limit profile name and the key the quota is consumed for.
   *
   * @param profile name passed to {@code RateLimitProfiles#getRateLimitProfile}
   * @param key identifier passed to the profile's {@code consume} call
   */
  public record ProfileAndKey(String profile, String key) {}

  /**
   * Determines which rate-limit profile applies to the request and which key to count it under.
   *
   * <p>When the session holds a security context whose authentication has a {@code String}
   * principal and at least one authority, the first authority becomes the profile and the
   * principal becomes the key. In every other case (no session, no stored security context, no
   * authentication, a non-{@code String} principal, or no authorities) the result is the {@code
   * "default"} profile keyed by {@link HttpServletRequest#getRemoteAddr()}.
   *
   * @param request the incoming request; an existing session is read but never created
   * @return the profile name and key to rate-limit the request with, never {@code null}
   */
  protected static ProfileAndKey extractProfileAndKeyFromRequest(HttpServletRequest request) {
    final var defaultProfileAndKey = new ProfileAndKey(DEFAULT_PROFILE, request.getRemoteAddr());

    // getSession(false): do not create a session just to rate-limit an anonymous request.
    final HttpSession httpSession = request.getSession(false);
    if (null == httpSession) {
      return defaultProfileAndKey;
    }

    // A session can exist without a stored security context; instanceof also rejects null and
    // unexpected attribute types instead of failing with NPE / ClassCastException.
    if (!(httpSession.getAttribute(SECURITY_CONTEXT_SESSION_ATTRIBUTE)
        instanceof SecurityContextImpl sci)) {
      return defaultProfileAndKey;
    }

    final var authentication = sci.getAuthentication();
    if (null == authentication) {
      return defaultProfileAndKey;
    }

    // Only a String principal is usable as a client id; anything else (including null) falls
    // back to the default profile rather than throwing on a cast.
    if (!(authentication.getPrincipal() instanceof String clientId)) {
      return defaultProfileAndKey;
    }

    // Only the first authority is considered; it is used verbatim as the profile name.
    return authentication.getAuthorities().stream()
        .findFirst()
        .map(GrantedAuthority::getAuthority)
        .map(profile -> new ProfileAndKey(profile, clientId))
        .orElse(defaultProfileAndKey);
  }

  /**
   * Consumes quota for the request, publishes the quota state as response headers and either
   * forwards the request or answers with {@code 429}.
   *
   * @param request the incoming request
   * @param response the response; always receives the {@code RateLimit-Policy} and {@code
   *     RateLimit} headers
   * @param filterChain the remaining chain, invoked only while the quota is not exceeded
   * @throws ServletException if the downstream chain fails
   * @throws IOException if writing the rejection body or the downstream chain fails
   */
  @Override
  public void doFilter(
      HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
      throws ServletException, IOException {

    final var profileAndKey = extractProfileAndKeyFromRequest(request);

    final var rateLimitProfile = rateLimitProfiles.getRateLimitProfile(profileAndKey.profile());
    final var quota = rateLimitProfile.consume(profileAndKey.key());

    // Headers are set for both accepted and rejected requests.
    response.setHeader("RateLimit-Policy", quota.rateLimitPolicy());
    response.setHeader("RateLimit", quota.rateLimit());

    if (quota.exceeded()) {
      // Short-circuit: the rest of the chain is not invoked for a rejected request.
      response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
      response.setContentType("text/plain");
      response.getWriter().append(HttpStatus.TOO_MANY_REQUESTS.getReasonPhrase());
    } else {
      filterChain.doFilter(request, response);
    }
  }
}