# Source code for q2_sdk.core.rate_limiter

import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps
from ipaddress import IPv4Network, ip_address
from json import JSONDecodeError
from typing import List, Optional

from q2_sdk.core.cache import get_cache
from tornado.httputil import HTTPServerRequest
from tornado.web import HTTPError

from q2_sdk.core.q2_logging.logger import Q2LoggerType


class RateLimitException(Exception):
    """Raised when persisted rate-limit state cannot be parsed."""
@dataclass
class RateLimit:
    """Snapshot of a single rate-limit bucket: token count plus refill time."""

    # Number of request tokens currently available in the bucket.
    tokens: int
    # When tokens were last added; serialized as an ISO 8601 timestamp.
    last_refilled: datetime

    def serialize(self) -> str:
        """
        Used to return a json representation of the self Ratelimit.
        Datetime format is ISO 8601 format.

        :return: JSON representation of the RateLimit
        """
        payload = {
            "tokens": self.tokens,
            "last_refilled": self.last_refilled.isoformat(),
        }
        return json.dumps(payload)

    @staticmethod
    def deserialize(json_str: str):
        """
        Takes a JSON string and parses it into a RateLimit

        :param json_str: The JSON string containing RateLimit information
        :return: A RateLimit object
        :raises RateLimitException: if the payload is not valid JSON or is
            missing/mistyping a field
        """
        try:
            parsed = json.loads(json_str)
            tokens = parsed["tokens"]
            refilled_at = datetime.fromisoformat(parsed["last_refilled"])
        except (JSONDecodeError, KeyError, ValueError, TypeError) as err:
            raise RateLimitException(
                "There was an error parsing the RateLimit json data"
            ) from err
        return RateLimit(tokens=tokens, last_refilled=refilled_at)
class RateLimiter:
    """
    Token-bucket rate limiter, persisted in the shared cache.

    Each user gets a maximum number of tokens representing allowed requests.
    These tokens deplete by one each request, refilling at a rate of
    refill_amount / refill_period (in seconds).

    For instance:
        max_tokens: 10
        refill_period: 5
        refill_amount: 2

    A user can use the extension to which this rate limiter is bound until her
    token count reaches 0. She hits the extension 10 times in the first second
    and is now denied access. 5 seconds later, she is granted 2 more tokens.
    Every 5 seconds 2 more tokens are added, up to a maximum of 10 again.

    State is saved in the cache (Memcached) under the key returned by
    :attr:`ratelimit_key`: ``ratelimit_<name>_<ip>`` when ``segment_by_ip`` is
    True, otherwise ``ratelimit_<name>``.
    """

    def __init__(
        self,
        logger: Q2LoggerType,
        max_tokens: int,
        refill_period: int,
        refill_amount: int,
        request: HTTPServerRequest,
        whitelist_networks: Optional[List[IPv4Network]] = None,
        blacklist_networks: Optional[List[IPv4Network]] = None,
        whitelist_regex=None,
        blacklist_regex=None,
        name="generic",
        segment_by_ip=True,
        is_proxied=True,
    ):
        """
        :param logger: Logger used for deny/bypass/deprecation messages
        :param max_tokens: Total number of attempts allowed before denial
        :param refill_period: In seconds
        :param refill_amount: Number to add each refill period
        :param request: User Request to reference IP and other info
        :param whitelist_networks: List of IPv4Networks that will be
            automatically allowed. i.e. IPv4Network('13.249.59.85/32')
        :param blacklist_networks: List of IPv4Networks that will be
            automatically denied. i.e. IPv4Network('13.249.59.85/32')
        :param whitelist_regex: (Deprecated) IP addresses that match this
            pattern will be automatically allowed
        :param blacklist_regex: (Deprecated) IP addresses that match this
            pattern will be automatically denied
        :param name: Will be appended to the cache key for uniqueness
        :param segment_by_ip: If True, each IP address gets its own rate
            limiting bucket
        :param is_proxied: If True, the ip address is pulled from the header
            x-forwarded-for value over the remote_ip value
        """
        self.logger = logger
        self.max_tokens = max_tokens
        self.refill_period = timedelta(seconds=refill_period)
        self.refill_amount = refill_amount
        self.remaining_tokens = self.max_tokens
        self.name = name
        self.ip_addr = request.remote_ip
        self.segment_by_ip = segment_by_ip
        if is_proxied:
            self.ip_addr = self._get_prox_ip_address(request)
        if whitelist_regex:
            self._log_deprecation("whitelist_regex", "whitelist_networks")
        if blacklist_regex:
            self._log_deprecation("blacklist_regex", "blacklist_networks")
        self.whitelist_regex = whitelist_regex
        self.blacklist_regex = blacklist_regex
        self.whitelist_networks = whitelist_networks or []
        self.blacklist_networks = blacklist_networks or []

    def _log_deprecation(self, old, new):
        # Warn once per construction that a regex knob should become a network list.
        self.logger.warning(
            "%s has been deprecated for RateLimiters. "
            "Please use %s instead in %s" % (old, new, self.name)
        )

    def _log_whitelist_msg(self):
        self.logger.debug(
            "%s matches whitelist. Bypassing rate limiter: %s"
            % (self.ip_addr, self.name)
        )

    def _log_blacklist_msg(self):
        self.logger.warning("%s matches blacklist. Denying access" % self.ip_addr)

    @property
    def ratelimit_key(self) -> str:
        """
        Returns the proper cache key depending on whether the
        ``segment_by_ip`` flag is set.

        :return: The key to select the ratelimit count
        """
        if self.segment_by_ip:
            return f"ratelimit_{self.name}_{self.ip_addr}"
        return f"ratelimit_{self.name}"

    def _get_prox_ip_address(self, request) -> str:
        """
        Given the request object return the ip address of the forwarded host.

        Prefers the first (originating-client) entry of ``x-forwarded-for``
        and falls back to ``request.remote_ip`` when the header is absent.

        :param request: The request object
        :return: String containing the ip address
        :raises ValueError: if the client-controlled forwarded header does not
            contain a valid IP address
        """
        remote_address = request.remote_ip
        forwarded_ip_string = request.headers.get("x-forwarded-for", None)
        if forwarded_ip_string:
            forwarded_ips = forwarded_ip_string.replace(" ", "").split(",")
            if len(forwarded_ips) > 1:
                self.logger.debug(
                    "Multiple forwarded ips received %s" % forwarded_ip_string
                )
            if forwarded_ips:
                # ip_address() both validates and normalizes the value
                remote_address = str(ip_address(forwarded_ips[0]))
        return remote_address

    @property
    def _matches_whitelist(self) -> bool:
        """
        Verifies that the given ip address is in the list of provided networks
        or the deprecated regex. Networks take precedence: the regex is only
        consulted when no whitelist networks were configured.

        :return: True if ip address fits into the given networks or regex
        """
        if self.whitelist_networks:
            address = ip_address(self.ip_addr)
            for network in self.whitelist_networks:
                if address in network:
                    self._log_whitelist_msg()
                    return True
        elif self.whitelist_regex and re.search(self.whitelist_regex, self.ip_addr):
            self._log_whitelist_msg()
            return True
        return False

    @property
    def _matches_blacklist(self):
        """
        Verifies that the given ip address is in the list of provided networks
        or the deprecated regex. Networks take precedence: the regex is only
        consulted when no blacklist networks were configured.

        :return: True if ip address fits into the given networks or regex
        """
        if self.blacklist_networks:
            address = ip_address(self.ip_addr)
            for network in self.blacklist_networks:
                if address in network:
                    self._log_blacklist_msg()
                    return True
        elif self.blacklist_regex and re.search(self.blacklist_regex, self.ip_addr):
            self._log_blacklist_msg()
            return True
        return False

    def is_allowed(self):
        """
        Checks if the request is allowed to proceed.

        Reads the persisted RateLimit from cache (starting a full bucket on a
        miss), applies any whole refill periods that have elapsed, consumes one
        token for this request, and writes the new state back.

        :return: True if the request is able to proceed, False otherwise
        :raises RateLimitException: if the cached state cannot be parsed
        """
        if self._matches_whitelist:
            return True
        if self._matches_blacklist:
            return False
        now = datetime.now(timezone.utc)
        cache = get_cache()

        # Restore the limit from cache or instantiate a new, full bucket
        try:
            cached_ratelimit = cache.get(self.ratelimit_key)
            if cached_ratelimit is None:
                limit = RateLimit(tokens=self.max_tokens, last_refilled=now)
            else:
                limit = RateLimit.deserialize(cached_ratelimit)
        except RateLimitException as e:
            raise RateLimitException(
                "There was an error parsing a RateLimit from key %s"
                % self.ratelimit_key
            ) from e

        # Whole refill periods elapsed since "last_refilled", clamped to a
        # non-negative time difference (guards against clock skew).
        number_of_refills = int(
            max((now - limit.last_refilled), timedelta()) / self.refill_period
        )

        # Cap the bucket at max_tokens BEFORE consuming this request's token.
        # (Capping after consumption handed out one extra request whenever
        # accumulated refills overflowed the bucket.)
        refilled_tokens = min(
            limit.tokens + number_of_refills * self.refill_amount,
            self.max_tokens,
        )
        self.remaining_tokens = refilled_tokens - 1
        limit.tokens = max(0, self.remaining_tokens)
        # Advance by whole periods only, preserving the fractional remainder.
        limit.last_refilled += number_of_refills * self.refill_period

        cache.set(
            self.ratelimit_key,
            limit.serialize(),
            # Entry may expire once an empty bucket would have fully refilled.
            expire=abs(
                int(
                    (self.max_tokens / self.refill_amount)
                    * self.refill_period.total_seconds()
                )
            ),
        )

        if self.remaining_tokens < 0:
            self.logger.warning(
                "Rate Limiter: %s. There are no remaining tokens for ip %s. Denying access"
                % (self.name, self.ip_addr)
            )
            return False

        self.logger.debug(
            "\n".join([
                f"Rate Limiter: {self.name}{f' IP: {self.ip_addr}' if self.segment_by_ip else ''}",
                f"Max Amount: {self.max_tokens}",
                f"Refill Period: {self.refill_period}",
                f"Refill Amount: {self.refill_amount}",
                f"Remaining Tokens: {self.remaining_tokens}",
            ])
        )
        return True
def rate_limit(
    max_tokens=1000,
    refill_period=60,
    refill_amount=18,
    whitelist_networks=None,
    blacklist_networks=None,
):
    """
    Decorator factory wiring a :class:`RateLimiter` in front of a handler
    method, with optional arguments to customize the limit.

    The cache-key naming follows the ``__qualname__`` of the wrapped function,
    similar to Python's default logging convention. For example::

        @rate_limit(max_tokens=1000)
        submit():

    wrapping ``submit()`` inside extension.py results in a cache entry located
    at approximately ``[PROJECTNAME].extension.submit_[IP ADDRESS]``.

    :param max_tokens: The maximum number of tokens that will be accumulated
    :param refill_period: The number of seconds between when the refill amount
        is added to the token count
    :param refill_amount: The number of tokens that are added every refill
        period
    :param whitelist_networks: List of IPv4Networks that will be automatically
        allowed. i.e. IPv4Network('13.249.59.85/32')
    :param blacklist_networks: List of IPv4Networks that will be automatically
        blocked. i.e. IPv4Network('13.249.59.85/32')
    :return: function wrapper
    :raises tornado.web.HTTPError: 429 when the limiter denies the request
    """

    def _decorate(handler_method):
        @wraps(handler_method)
        async def limited(self, *args, **kwargs):
            # Build a fresh limiter per request; state lives in the cache,
            # keyed on the wrapped method's qualified name (plus client IP).
            guard = RateLimiter(
                logger=self.logger,
                max_tokens=max_tokens,
                refill_period=refill_period,
                refill_amount=refill_amount,
                request=self.request,
                whitelist_networks=whitelist_networks,
                blacklist_networks=blacklist_networks,
                segment_by_ip=True,
                is_proxied=True,
                name=handler_method.__qualname__,
            )
            if guard.is_allowed():
                return await handler_method(self, *args, **kwargs)
            raise HTTPError(429)

        return limited

    return _decorate