You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
109 lines
3.6 KiB
109 lines
3.6 KiB
import logging
|
|
import traceback
|
|
|
|
from django_redis import get_redis_connection
|
|
from rest_framework.throttling import ScopedRateThrottle
|
|
|
|
import ChaCeRndTrans
|
|
from ChaCeRndTrans.basic import CCAIResponse
|
|
|
|
err_logger = logging.getLogger('error')
|
|
|
|
class CustomThrottle(ScopedRateThrottle):
    """
    Scoped throttle that allows one request per time window.

    For every (request path, throttle scope, client identity) combination a
    single timestamp -- the time of the last accepted request -- is stored
    under a key in ``cache``.  A new request is rejected while that
    timestamp is still inside the configured window.

    Rates for this throttle are written as ``<integer><unit>`` where the
    unit is one of ``s``, ``m``, ``h`` or ``d`` (for example ``'5s'`` or
    ``'10m'``); see ``parse_rate``.
    """

    cache_format = 'throttle_%(path)s_%(scope)s_%(ident)s'
    # Connection for the 'default' alias, obtained from django_redis.
    # Keys built by ``get_cache_key`` are passed to it as-is.
    cache = get_redis_connection('default')

    def parse_rate(self, rate):
        """
        Convert a rate string into a throttle window length in seconds.

        ``rate`` is an integer immediately followed by a one-letter unit:
        ``s`` (seconds), ``m`` (minutes), ``h`` (hours) or ``d`` (days).

        Returns ``None`` when ``rate`` is ``None``, otherwise the window
        length in seconds as an ``int``.

        Raises ``ValueError`` if the numeric part is not an integer and
        ``KeyError`` if the unit letter is not recognised.
        """
        if rate is None:
            return None
        num, period = rate[:-1], rate[-1]
        return int(num) * {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period]

    def allow_request(self, request, view):
        """
        Decide whether ``request`` may proceed.

        Returns ``True`` when the view has no throttle scope, when no rate
        is configured for the scope, when no cache key can be built, or
        when no stored timestamp falls inside the current window (in which
        case the current time is recorded).  Otherwise returns the result
        of ``throttle_failure()``.

        Any error raised while evaluating the throttle is logged to the
        ``error`` logger with its traceback and re-raised as a generic
        ``Exception``.
        """
        try:
            # We can only determine the scope once we're called by the view.
            self.scope = getattr(view, self.scope_attr, None)

            # If a view does not have a `throttle_scope` always allow the
            # request.
            if not self.scope:
                return True

            # Determine the allowed request rate as we normally would
            # during the `__init__` call.
            self.rate = self.get_rate()
            self.duration = self.parse_rate(self.rate)

            if self.rate is None:
                return True

            # Without a cache key there is nothing to throttle on.
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True

            # The stored value is the timestamp of the last accepted
            # request; a missing/empty value means no request is on record.
            tmp_history = self.cache.get(self.key)
            self.history = float(tmp_history) if tmp_history else None
            self.now = self.timer()

            if self.history is None:
                return self.throttle_success()
            # The last accepted request is still inside the window.
            if self.history >= self.now - self.duration:
                return self.throttle_failure()
            return self.throttle_success()
        except Exception as e:
            # `self.key` is only assigned part-way through the try block, so
            # look it up defensively: a plain attribute access here would
            # raise AttributeError and hide the original failure.
            err_logger.error(
                "keys: %s throttle failed: \n%s",
                getattr(self, 'key', None),
                traceback.format_exc(),
            )
            raise Exception("限流器错误") from e

    def throttle_success(self):
        """
        Record the current request's timestamp under the cache key.

        ``self.duration`` is passed as the third positional argument of
        ``cache.set`` (the expiry for a redis-py style client -- confirm
        against the client in use).  Always returns ``True``.
        """
        self.history = self.now
        self.cache.set(self.key, self.history, self.duration)
        return True

    def get_cache_key(self, request, view):
        """
        Build the cache key for this request.

        The key combines the request path, the throttle scope and the
        client identity: the user's primary key for authenticated users,
        otherwise the value returned by ``get_ident(request)``.
        """
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'path': request.path,
            'scope': self.scope,
            'ident': ident
        }

    def wait(self):
        """
        Return the number of seconds until the next request is allowed.

        When a previous timestamp is on record this is the remainder of the
        current window; otherwise it is the full window length.
        """
        if self.history:
            return self.duration - (self.now - self.history)
        return self.duration
|