Rate Limiting Spring Reactive Web APIs (Bucket4J)

Based on a couple of other good guides (see reference [1]), updated for Spring WebFlux, Spring's reactive web framework.

Goal

Rate limiting within Reactive Spring Web can be achieved in a number of ways. To cover a couple of different scenarios, we define the following goals:

  • Rate limit all API calls by session
  • Rate limit all API calls by Source IP address

Note that in both goals we are interested in rate limiting all API calls (i.e. not a subset of URLs from a given API).

Dependencies / Libraries

Getting started with these libraries is straightforward. First generate the required Spring Boot project from https://start.spring.io/ (selecting the Spring Reactive Web dependency), then add the Bucket4J dependency to the Maven build.

Approach to Solution

Both our goals state that we need to rate limit all API calls. In Spring, when an operation is required on all API calls rather than just a subset of URLs, Web Filters should spring to mind. A Web Filter needs to override the “filter” method:

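import org.springframework.stereotype.Component
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono

@Component
class ApiFilter : WebFilter {

    override fun filter(serverWebExchange: ServerWebExchange,
                        webFilterChain: WebFilterChain): Mono<Void> {
        TODO("rate limiting logic goes here")
    }
}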

Note the filter method must return a Mono<Void> and accepts two arguments:

  • serverWebExchange: allows you to access the HTTP request sent to the server as well as the HTTP response that will eventually be returned from the server
  • webFilterChain: filters are organized in an abstract “chain”, with each filter processing the request and handing it off to the next filter in line

The filter method can do one of two things:

  • Process the incoming request, and terminate the exchange by immediately sending a response back to the client. In this case the filter function would modify the response (via serverWebExchange.response) and return a Mono.empty()
  • Process the incoming request, and pass along the request to the next filter in the chain. In this case the function would:
return webFilterChain.filter(serverWebExchange)
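
The two outcomes described above can be sketched side by side. This is a minimal illustration rather than the rate limiter itself: GateFilter and its isAllowed parameter are hypothetical names, and the 403 status is chosen only for the example.

```kotlin
import org.springframework.http.HttpStatus
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono

// Hypothetical filter: `isAllowed` stands in for whatever check a real filter performs.
class GateFilter(private val isAllowed: (ServerWebExchange) -> Boolean) : WebFilter {

    override fun filter(serverWebExchange: ServerWebExchange,
                        webFilterChain: WebFilterChain): Mono<Void> {
        if (!isAllowed(serverWebExchange)) {
            // Option 1: modify the response and end the exchange here
            serverWebExchange.response.statusCode = HttpStatus.FORBIDDEN
            return Mono.empty()
        }
        // Option 2: hand the request to the next filter in the chain
        return webFilterChain.filter(serverWebExchange)
    }
}
```

In the rejecting branch the rest of the chain (and therefore the controller) is never invoked, which is exactly what a rate limiter needs.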

With this in mind, we can now build a filter which enforces our API rate limiting

Rate Limiting By Source IP

This is the simpler goal. We need to track the source IP address of each request, retrieve its associated bucket, and return a response accordingly. We need an object to store the bucket for each source IP; in this case we use a hash map. The key is the source IP and the value is the bucket itself:

import io.github.bucket4j.Bucket
import java.util.concurrent.ConcurrentHashMap

object RateLimitingCache {
    val byIP = ConcurrentHashMap<String, Bucket>() // thread-safe: requests arrive on several threads
}

If the Source IP is a new one, we need to create a bucket to store in the above hash map:

private fun createIpRateLimitBucket(): Bucket {
    // allow 500 requests per 10 minutes for each IP
    val limit = Bandwidth.simple(500, Duration.ofMinutes(10))
    return Bucket4j.builder().addLimit(limit).build()
}
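
To see how such a bucket behaves before wiring it into the filter, here is a small sketch that uses the same Bucket4j.builder() calls as above. The capacity of 3 and the createDemoBucket name are assumptions made purely for illustration.

```kotlin
import io.github.bucket4j.Bandwidth
import io.github.bucket4j.Bucket
import io.github.bucket4j.Bucket4j
import java.time.Duration

// Same construction as createIpRateLimitBucket(), with a tiny capacity for illustration
fun createDemoBucket(): Bucket {
    val limit = Bandwidth.simple(3, Duration.ofMinutes(10))
    return Bucket4j.builder().addLimit(limit).build()
}

fun main() {
    val bucket = createDemoBucket()
    // A new bucket starts full, so three calls succeed and the fourth is refused
    val results = (1..4).map { bucket.tryConsume(1) }
    println(results) // [true, true, true, false]
}
```

tryConsume never blocks: it returns false as soon as the bucket is empty, which makes it safe to call from a reactive filter.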

Last, we need to modify our filter function to enforce our limits:

override fun filter(serverWebExchange: ServerWebExchange,
                    webFilterChain: WebFilterChain): Mono<Void> {
    val sourceIP = serverWebExchange.request.remoteAddress!!.address.hostAddress
    // look up the bucket for this IP, creating one the first time we see it
    val bucket = RateLimitingCache.byIP.computeIfAbsent(sourceIP) {
        println("Creating new IP bucket...")
        createIpRateLimitBucket()
    }
    println("Available IP tokens left: ${bucket.availableTokens}")
    // every request, including the first one from a new IP, costs one token
    if (!bucket.tryConsume(1)) {
        serverWebExchange.response.statusCode = HttpStatus.BANDWIDTH_LIMIT_EXCEEDED
        return Mono.empty()
    }

    return webFilterChain.filter(serverWebExchange)
}

Notes:

  • We use serverWebExchange.request to access the client IP address (note that behind a reverse proxy this is the proxy's address, so you may need to change this)
  • If the map holds a bucket for this IP, we try to consume a token from it; a new IP gets a freshly created bucket. If no tokens are left, we set the status via serverWebExchange.response.statusCode and return a Mono.empty()
  • Otherwise, we simply pass along the request to the next filter in the chain via webFilterChain.filter()
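
For the reverse-proxy case mentioned in the notes, one common approach is to read the X-Forwarded-For header instead. The helper below is a hypothetical sketch (resolveClientIp is not part of Spring or Bucket4j), and it assumes the proxy is trusted and sets the header itself, since clients can otherwise forge it.

```kotlin
// Hypothetical helper: picks the client IP when the app sits behind a trusted reverse proxy.
// `forwardedFor` is the raw X-Forwarded-For header value (may be absent);
// `remoteAddress` is the address the connection actually came from.
fun resolveClientIp(forwardedFor: String?, remoteAddress: String): String {
    // X-Forwarded-For is a comma-separated list; the left-most entry is the original client
    val firstHop = forwardedFor
        ?.split(",")
        ?.map { it.trim() }
        ?.firstOrNull { it.isNotEmpty() }
    return firstHop ?: remoteAddress
}

fun main() {
    println(resolveClientIp("203.0.113.7, 10.0.0.2", "10.0.0.1")) // 203.0.113.7
    println(resolveClientIp(null, "198.51.100.4"))                 // 198.51.100.4
}
```

Inside the filter, the header value could be read with serverWebExchange.request.headers.getFirst("X-Forwarded-For") and the result used as the map key in place of sourceIP.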

Rate Limiting By Session

This goal is a bit more involved. An HTTP user session is typically tracked through a cookie. This means there’s no need for a HashMap to store the bucket as we did above: it can be stored in the user session itself. Creating the session bucket remains very similar to the above:

private fun createSessionRateLimitBucket(): Bucket {
    val limit = Bandwidth.simple(100, Duration.ofMinutes(10))
    return Bucket4j.builder().addLimit(limit).build()
}

Enforcing the rate limit in the filter is a bit more involved, since we need to extract the session from the serverWebExchange using reactive-style programming, without relying on .block(), which must not be called on Spring WebFlux's non-blocking request threads.

The code to achieve this is the serverWebExchange.session chain at the end of the filter method in the complete WebFilter code below.

The code is pretty well commented. Note how by using flatMap we return a Mono<Void> as required by the ‘filter’ method.
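
As a compact illustration of the flatMap step, the session check can also be sketched as a stand-alone filter. SessionRateLimitFilter and its capacity parameter are hypothetical names introduced here so the sketch is easy to try with a small limit; it also assumes the default in-memory session store, where any object can be kept as a session attribute.

```kotlin
import io.github.bucket4j.Bandwidth
import io.github.bucket4j.Bucket
import io.github.bucket4j.Bucket4j
import org.springframework.http.HttpStatus
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono
import java.time.Duration

// Hypothetical stand-alone filter; `capacity` is a parameter only to make the sketch easy to try
class SessionRateLimitFilter(private val capacity: Long = 100) : WebFilter {

    private fun createSessionRateLimitBucket(): Bucket {
        val limit = Bandwidth.simple(capacity, Duration.ofMinutes(10))
        return Bucket4j.builder().addLimit(limit).build()
    }

    override fun filter(serverWebExchange: ServerWebExchange,
                        webFilterChain: WebFilterChain): Mono<Void> =
        // session is a Mono<WebSession>; flatMap runs once it resolves, so nothing blocks
        serverWebExchange.session.flatMap { webSession ->
            // reuse the bucket stored in the session, or create and store a new one
            val bucket = webSession.attributes
                .getOrPut("bucket") { createSessionRateLimitBucket() } as Bucket
            if (bucket.tryConsume(1)) {
                // within the limit: continue down the chain
                webFilterChain.filter(serverWebExchange)
            } else {
                // over the limit: set the status and end the exchange
                serverWebExchange.response.statusCode = HttpStatus.BANDWIDTH_LIMIT_EXCEEDED
                Mono.empty<Void>()
            }
        }
}
```

Both branches of the lambda produce a Mono<Void>, so the flatMap result can be returned from filter directly.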

Complete WebFilter code

import io.github.bucket4j.Bandwidth
import io.github.bucket4j.Bucket
import io.github.bucket4j.Bucket4j
import org.springframework.http.HttpStatus
import org.springframework.stereotype.Component
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono
import java.time.Duration
import java.util.concurrent.ConcurrentHashMap

object RateLimitingCache {
    // thread-safe map: WebFlux handles requests on several threads
    val byIP = ConcurrentHashMap<String, Bucket>()
}

@Component
class ApiFilter : WebFilter {

    private fun createSessionRateLimitBucket(): Bucket {
        val limit = Bandwidth.simple(100, Duration.ofMinutes(10))
        return Bucket4j.builder().addLimit(limit).build()
    }

    private fun createIpRateLimitBucket(): Bucket {
        val limit = Bandwidth.simple(500, Duration.ofMinutes(10))
        return Bucket4j.builder().addLimit(limit).build()
    }

    override fun filter(serverWebExchange: ServerWebExchange,
                        webFilterChain: WebFilterChain): Mono<Void> {

        val sourceIP = serverWebExchange.request.remoteAddress!!.address.hostAddress
        // look up the bucket for this IP, creating one the first time we see it
        val ipBucket = RateLimitingCache.byIP.computeIfAbsent(sourceIP) {
            println("Creating new IP bucket...")
            createIpRateLimitBucket()
        }
        println("Available IP tokens left: ${ipBucket.availableTokens}")
        // every request, including the first one from a new IP, costs one token
        if (!ipBucket.tryConsume(1)) {
            serverWebExchange.response.statusCode = HttpStatus.BANDWIDTH_LIMIT_EXCEEDED
            return Mono.empty()
        }

        return serverWebExchange.session
            // use flatMap to extract the WebSession object from serverWebExchange
            .flatMap { webSession ->
                // check if a bucket already exists for this session
                if (webSession.attributes.containsKey("bucket")) {
                    // if it does - extract the bucket from the session
                    val bucket = webSession.attributes["bucket"] as Bucket
                    // consume a token
                    if (bucket.tryConsume(1)) {
                        // if allowed - i.e. not over the allocated rate,
                        // then pass request on to the next filter in the chain
                        println("Available session tokens left : ${bucket.availableTokens}")
                        webFilterChain.filter(serverWebExchange)
                    } else {
                        // if not allowed then modify response code and immediately return to client
                        serverWebExchange.response.statusCode = HttpStatus.BANDWIDTH_LIMIT_EXCEEDED
                        Mono.empty<Void>()
                    }
                } else {
                    // if bucket does not exist create a new one
                    val bucket = createSessionRateLimitBucket()
                    println("Creating new session bucket...")
                    // save bucket to session
                    webSession.attributes["bucket"] = bucket
                    // the first request of a session also costs one token
                    bucket.tryConsume(1)
                    // pass on the request to the next filter in the chain
                    webFilterChain.filter(serverWebExchange)
                }
            }
    }
}

References

[1] https://www.baeldung.com/spring-bucket4j