import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps
from ipaddress import IPv4Network, ip_address
from json import JSONDecodeError
from typing import List, Optional
from q2_sdk.core.cache import get_cache
from tornado.httputil import HTTPServerRequest
from tornado.web import HTTPError
from q2_sdk.core.q2_logging.logger import Q2LoggerType
class RateLimitException(Exception):
    """Raised when rate-limit state cannot be parsed or processed."""
@dataclass
class RateLimit:
    """
    State of a single token bucket: the remaining token count and the
    moment the bucket was last topped up.
    """

    tokens: int
    last_refilled: datetime

    def serialize(self) -> str:
        """
        Render this RateLimit as a JSON string. Datetime values use
        ISO 8601 format.

        :return: JSON representation of the RateLimit
        """

        def _encode(value):
            # datetimes become ISO strings; anything else (the RateLimit
            # itself) is serialized via its attribute dict.
            if isinstance(value, datetime):
                return value.isoformat()
            return value.__dict__

        return json.dumps(self, default=_encode)

    @staticmethod
    def deserialize(json_str: str):
        """
        Parse a JSON string into a RateLimit.

        :param json_str: The JSON string containing RateLimit information
        :return: A RateLimit object
        :raises RateLimitException: if the payload is not valid RateLimit JSON
        """
        try:
            parsed = json.loads(json_str)
            restored = RateLimit(
                tokens=parsed["tokens"],
                last_refilled=datetime.fromisoformat(parsed["last_refilled"]),
            )
        except (JSONDecodeError, KeyError, ValueError, TypeError) as err:
            raise RateLimitException(
                "There was an error parsing the RateLimit json data"
            ) from err
        return restored
class RateLimiter:
    """
    Token-bucket rate limiter whose state lives in Memcached.

    Each user gets a maximum number of tokens representing allowed
    requests. These tokens deplete by one each request, refilling at a rate
    of refill_amount / refill_period (in seconds).

    For instance::

        max_tokens: 10
        refill_period: 5
        refill_amount: 2

    A user can use the extension to which this rate limiter is bound until her
    token count reaches 0. She hits the extension 10 times in the first second
    and is now denied access. 5 seconds later, she is granted 2 more tokens.
    Every 5 seconds 2 more tokens are added to the count, up to a maximum of 10.

    The state is saved in Memcached as a single serialized ``RateLimit``
    under the key produced by :attr:`ratelimit_key`
    (``ratelimit_<name>`` or ``ratelimit_<name>_<ip>``).
    """

    def __init__(
        self,
        logger: "Q2LoggerType",
        max_tokens: int,
        refill_period: int,
        refill_amount: int,
        request: "HTTPServerRequest",
        whitelist_networks: Optional[List[IPv4Network]] = None,
        blacklist_networks: Optional[List[IPv4Network]] = None,
        whitelist_regex=None,
        blacklist_regex=None,
        name="generic",
        segment_by_ip=True,
        is_proxied=True,
    ):
        """
        :param logger: Logger used for debug/warning output
        :param max_tokens: Total number of attempts allowed before denial
        :param refill_period: In seconds
        :param refill_amount: Number to add each refill period
        :param request: User Request to reference IP and other info
        :param whitelist_networks: List of IPv4Networks that will be automatically allowed.
            i.e. IPv4Network('13.249.59.85/32')
        :param blacklist_networks: List of IPv4Networks that will be automatically denied.
            i.e. IPv4Network('13.249.59.85/32')
        :param whitelist_regex: (Deprecated) IP addresses that match this pattern will be automatically allowed
        :param blacklist_regex: (Deprecated) IP addresses that match this pattern will be automatically denied
        :param name: Will be appended to the cache key for uniqueness
        :param segment_by_ip: If True, each IP address gets its own rate limiting bucket
        :param is_proxied: If True, the ip address is pulled from the header x-forwarded-for
            value over the remote_ip value
        """
        self.logger = logger
        self.max_tokens = max_tokens
        self.refill_period = timedelta(seconds=refill_period)
        self.refill_amount = refill_amount
        # Updated on every is_allowed() call; starts full.
        self.remaining_tokens = self.max_tokens
        self.name = name
        self.ip_addr = request.remote_ip
        self.segment_by_ip = segment_by_ip
        if is_proxied:
            # Behind a proxy, remote_ip is the proxy; prefer x-forwarded-for.
            self.ip_addr = self._get_prox_ip_address(request)
        if whitelist_regex:
            self._log_deprecation("whitelist_regex", "whitelist_networks")
        if blacklist_regex:
            self._log_deprecation("blacklist_regex", "blacklist_networks")
        self.whitelist_regex = whitelist_regex
        self.blacklist_regex = blacklist_regex
        self.whitelist_networks = whitelist_networks or []
        self.blacklist_networks = blacklist_networks or []

    def _log_deprecation(self, old, new):
        # Warn once per construction when a deprecated regex parameter is used.
        self.logger.warning(
            "%s has been deprecated for RateLimiters. "
            "Please use %s instead in %s" % (old, new, self.name)
        )

    def _log_whitelist_msg(self):
        self.logger.debug(
            "%s matches whitelist. Bypassing rate limiter: %s"
            % (self.ip_addr, self.name)
        )

    def _log_blacklist_msg(self):
        self.logger.warning("%s matches blacklist. Denying access" % self.ip_addr)

    @property
    def ratelimit_key(self) -> str:
        """
        Return the cache key for this limiter, per-IP when
        ``segment_by_ip`` is True.

        :return: The key to select the ratelimit count
        """
        if self.segment_by_ip:
            return f"ratelimit_{self.name}_{self.ip_addr}"
        return f"ratelimit_{self.name}"

    def _get_prox_ip_address(self, request) -> str:
        """
        Given the request object return the ip address of the forwarded host.

        Falls back to ``request.remote_ip`` when no x-forwarded-for header
        is present.

        :param request: The request object
        :return: String containing the ip address
        """
        remote_address = request.remote_ip
        forwarded_ip_string = request.headers.get("x-forwarded-for", None)
        if forwarded_ip_string:
            forwarded_ips = forwarded_ip_string.replace(" ", "").split(",")
            if len(forwarded_ips) > 1:
                self.logger.debug(
                    "Multiple forwarded ips received %s" % forwarded_ip_string
                )
            if forwarded_ips:
                # The first entry is conventionally the originating client.
                remote_address = str(
                    ip_address(forwarded_ips[0])
                )  # some simple ip address validation and sanitization
        return remote_address

    @property
    def _matches_whitelist(self) -> bool:
        """
        Verifies that the given ip address is in the list of provided networks
        or matches the deprecated regex. Networks take precedence: when any
        whitelist networks are configured, the regex is not consulted.

        :return: True if ip address fits into the given networks or regex
        """
        if len(self.whitelist_networks) > 0:
            address = ip_address(self.ip_addr)
            for network in self.whitelist_networks:
                if address in network:
                    self._log_whitelist_msg()
                    return True
        elif self.whitelist_regex and re.search(self.whitelist_regex, self.ip_addr):
            self._log_whitelist_msg()
            return True
        return False

    @property
    def _matches_blacklist(self):
        """
        Verifies that the given ip address is in the list of provided networks
        or matches the deprecated regex. Networks take precedence: when any
        blacklist networks are configured, the regex is not consulted.

        :return: True if ip address fits into the given networks or regex
        """
        if len(self.blacklist_networks) > 0:
            address = ip_address(self.ip_addr)
            for network in self.blacklist_networks:
                if address in network:
                    self._log_blacklist_msg()
                    return True
        elif self.blacklist_regex and re.search(self.blacklist_regex, self.ip_addr):
            self._log_blacklist_msg()
            return True
        return False

    def is_allowed(self):
        """
        Checks if the request is allowed to proceed. This function reads the
        cached state, credits any refills earned since the last update,
        consumes one token for this request, and writes the new state back.

        :return: True if the request is able to proceed, False otherwise
        :raises RateLimitException: if the cached state cannot be parsed
        """
        if self._matches_whitelist:
            return True
        if self._matches_blacklist:
            return False
        now = datetime.now(timezone.utc)
        cache = get_cache()
        # Restore the limit from cache or instantiate a new, full bucket
        try:
            cached_ratelimit = cache.get(self.ratelimit_key)
            if cached_ratelimit is None:
                limit = RateLimit(tokens=self.max_tokens, last_refilled=now)
            else:
                limit = RateLimit.deserialize(cached_ratelimit)
        except RateLimitException as e:
            raise RateLimitException(
                "There was an error parsing a RateLimit from key %s"
                % self.ratelimit_key
            ) from e
        # Calculate the number of whole refill periods that have elapsed
        # between "last_refilled" and "now", clamping the difference to a
        # positive value in case of clock skew.
        number_of_refills = int(
            max((now - limit.last_refilled), timedelta()) / self.refill_period
        )
        # Credit the refills and consume one token for this request, capped
        # at max_tokens.
        # NOTE(review): the -1 is applied before the min() cap, so a request
        # arriving while the bucket would overflow past max is effectively
        # not charged a token — confirm this generosity is intentional.
        self.remaining_tokens = min(
            limit.tokens + number_of_refills * self.refill_amount - 1,
            self.max_tokens,
        )
        limit.tokens = max(0, self.remaining_tokens)
        # Advance last_refilled only by the periods actually credited so a
        # partial period keeps accruing toward the next refill.
        limit.last_refilled += number_of_refills * self.refill_period
        cache.set(
            self.ratelimit_key,
            limit.serialize(),
            # Keep the entry alive long enough to refill an empty bucket.
            # NOTE(review): refill_amount == 0 raises ZeroDivisionError here.
            expire=abs(
                int(
                    (self.max_tokens / self.refill_amount)
                    * self.refill_period.total_seconds()
                )
            ),
        )
        if self.remaining_tokens < 0:
            self.logger.warning(
                "Rate Limiter: %s. There are no remaining tokens for ip %s. Denying access"
                % (self.name, self.ip_addr)
            )
            return False
        self.logger.debug(
            "\n".join([
                f"Rate Limiter: {self.name}{f' IP: {self.ip_addr}' if self.segment_by_ip else ''}",
                f"Max Amount: {self.max_tokens}",
                f"Refill Period: {self.refill_period}",
                f"Refill Amount: {self.refill_amount}",
                f"Remaining Tokens: {self.remaining_tokens}",
            ])
        )
        return True
def rate_limit(
    max_tokens=1000,
    refill_period=60,
    refill_amount=18,
    whitelist_networks=None,
    blacklist_networks=None,
):
    """
    Decorator that guards an async handler with a RateLimiter. Optional
    arguments customize the rate limit.

    The caching naming convention uses the __name__ methodology similar to
    Python's default logging convention. For example::

        @rate_limit(max_tokens=1000)
        submit():

    This wraps the submit() function inside the extension.py file and will
    result in a cache located at approximately
    [PROJECTNAME].extension.submit_[IP ADDRESS]

    :param max_tokens: The maximum number of tokens that will be accumulated
    :param refill_period: The number of seconds between when the refill amount
        is added to the token count
    :param refill_amount: The number of tokens that are added every refill period
    :param whitelist_networks: List of IPv4Networks that will be automatically
        allowed. i.e. IPv4Network('13.249.59.85/32')
    :param blacklist_networks: List of IPv4Networks that will be automatically
        denied. i.e., IPv4Network('13.249.59.85/32')
    :return: function wrapper
    """

    def decorate(handler_func):
        @wraps(handler_func)
        async def guarded(self, *args, **kwargs):
            # A fresh limiter per request; persistent state lives in the
            # cache, keyed by the wrapped function's qualified name.
            limiter = RateLimiter(
                logger=self.logger,
                max_tokens=max_tokens,
                refill_period=refill_period,
                refill_amount=refill_amount,
                request=self.request,
                whitelist_networks=whitelist_networks,
                blacklist_networks=blacklist_networks,
                segment_by_ip=True,
                is_proxied=True,
                name=handler_func.__qualname__,
            )
            if limiter.is_allowed():
                return await handler_func(self, *args, **kwargs)
            # Too Many Requests
            raise HTTPError(429)

        return guarded

    return decorate