diff --git a/README.rst b/README.rst index 8bcd748..3881dd1 100644 --- a/README.rst +++ b/README.rst @@ -52,6 +52,17 @@ Example: 600 requests per minute except TooManyRequests: return '429 Too Many Requests' +Example: 1 request per 50 milliseconds + +.. code-block:: python + + from redis_rate_limit import RateLimit, TooManyRequests, TimeUnit + try: + with RateLimit(resource='users_list', client='192.168.0.10', max_requests=1, expire=50, time_unit=TimeUnit.MILLISECOND): + return '200 OK' + except TooManyRequests: + return '429 Too Many Requests' + Example: 100 requests per hour .. code-block:: python diff --git a/redis_rate_limit/__init__.py b/redis_rate_limit/__init__.py index b08e96d..1750713 100644 --- a/redis_rate_limit/__init__.py +++ b/redis_rate_limit/__init__.py @@ -7,19 +7,26 @@ from redis.exceptions import NoScriptError from redis import Redis, ConnectionPool +from enum import Enum + + +class TimeUnit(Enum): + SECOND = "second" + MILLISECOND = "millisecond" + # Adapted from http://redis.io/commands/incr#pattern-rate-limiter-2 INCREMENT_SCRIPT = b""" local current current = tonumber(redis.call("incrby", KEYS[1], ARGV[2])) if current == tonumber(ARGV[2]) then - redis.call("expire", KEYS[1], ARGV[1]) + redis.call("PEXPIRE", KEYS[1], ARGV[1]) end return current """ INCREMENT_SCRIPT_HASH = sha1(INCREMENT_SCRIPT).hexdigest() -REDIS_POOL = ConnectionPool(host='127.0.0.1', port=6379, db=0) +REDIS_POOL = ConnectionPool(host="127.0.0.1", port=6379, db=0) class RedisVersionNotSupported(Exception): @@ -27,6 +34,7 @@ class RedisVersionNotSupported(Exception): Rate Limit depends on Redis’ commands EVALSHA and EVAL which are only available since the version 2.6.0 of the database. """ + pass @@ -35,6 +43,7 @@ class TooManyRequests(Exception): Occurs when the maximum number of requests is reached for a given resource of an specific user. """ + pass @@ -43,7 +52,16 @@ class RateLimit(object): This class offers an abstraction of a Rate Limit algorithm implemented on top of Redis >= 2.6.0. """ - def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS_POOL): + + def __init__( + self, + resource, + client, + max_requests, + expire=None, + redis_pool=REDIS_POOL, + time_unit: TimeUnit = TimeUnit.SECOND, + ): """ Class initialization method checks if the Rate Limit algorithm is actually supported by the installed Redis version and sets some @@ -65,6 +83,7 @@ def __init__(self, resource, client, max_requests, expire=None, redis_pool=REDIS self._rate_limit_key = "rate_limit:{0}_{1}".format(resource, client) self._max_requests = max_requests self._expire = expire or 1 + self._expire = expire * 1000 if time_unit == TimeUnit.SECOND else expire def __call__(self, func): """ @@ -130,21 +149,32 @@ def increment_usage(self, increment_by=1): :return: integer: current usage """ if increment_by > self._max_requests: - raise ValueError('increment_by {increment_by} overflows ' - 'max_requests of {max_requests}' - .format(increment_by=increment_by, - max_requests=self._max_requests)) + raise ValueError( + "increment_by {increment_by} overflows " + "max_requests of {max_requests}".format( + increment_by=increment_by, max_requests=self._max_requests + ) + ) elif increment_by <= 0: - raise ValueError('{increment_by} is not a valid increment, ' - 'should be greater than or equal to zero.' - .format(increment_by=increment_by)) + raise ValueError( + "{increment_by} is not a valid increment, " + "should be greater than or equal to zero.".format( + increment_by=increment_by + ) + ) try: current_usage = self._redis.evalsha( - INCREMENT_SCRIPT_HASH, 1, self._rate_limit_key, self._expire, increment_by) + INCREMENT_SCRIPT_HASH, + 1, + self._rate_limit_key, + self._expire, + increment_by, + ) except NoScriptError: current_usage = self._redis.eval( - INCREMENT_SCRIPT, 1, self._rate_limit_key, self._expire, increment_by) + INCREMENT_SCRIPT, 1, self._rate_limit_key, self._expire, increment_by + ) if int(current_usage) > self._max_requests: raise TooManyRequests() @@ -160,13 +190,14 @@ def _is_rate_limit_supported(self): """ redis_version = self._redis.info()['redis_version'] is_supported = Version(redis_version) >= Version('2.6.0') + return bool(is_supported) def _reset(self): """ Deletes all keys that start with ‘rate_limit:’. """ - matching_keys = self._redis.scan_iter(match='{0}*'.format('rate_limit:*')) + matching_keys = self._redis.scan_iter(match="{0}*".format("rate_limit:*")) for rate_limit_key in matching_keys: self._redis.delete(rate_limit_key) @@ -180,7 +211,7 @@ def __init__(self, resource, max_requests, expire=None, redis_pool=REDIS_POOL): :param expire: seconds to wait before resetting counters (i.e. ‘60’) :param redis_pool: instance of redis.ConnectionPool. Default: ConnectionPool(host='127.0.0.1', port=6379, db=0) - """ + """ self.resource = resource self.max_requests = max_requests self.expire = expire