限流
可以对接口访问的频次进行限制,以减轻服务器压力。
自定义限流思路: 实现机制:自定义一个字典,key 为唯一标识,可以是 ip 地址、用户名、用户id 等,value 为一个列表,存储时间,如限制十秒内访问三次:
{"ip": [16:20:40, 16:20:30, 16:20:20]}(列表按时间倒序存放:新的访问时间插入到最前面,最早的访问时间在最后)
思路:
如果用户没访问过,就为该用户创建访问记录并记录本次访问时间
如果用户已经访问过,获取该用户的访问历史,并判断列表最后一条(即最早的一次)访问时间是否小于当前时间减去 10 秒
如果小于,说明这条记录已经超出 10 秒的时间窗口,就把它删除;循环执行,直到最后一条记录仍在窗口内或列表为空
限制为 10 秒内访问三次,因此判断列表长度是否小于三:
如果小于三,说明可以访问,在第 0 位插入本次访问时间并放行;否则拒绝访问
自定义限流代码实现:需要继承 drf 的基础限流类 BaseThrottle,并在里面重写 allow_request(返回 True 表示放行,返回 False 表示被限流),这条要记清楚,不重写会报错;还有一个 wait 方法,返回还需要等待的秒数(动态计算)。
# Visit history per client IP, newest timestamp first. This is a module-level
# dict, so the counts live only in this process's memory.
VISIT_RECORD = {}


class MyThrottle(BaseThrottle):
    """Allow at most ``num_requests`` requests per ``duration`` seconds per client IP."""

    # Length of the sliding window in seconds, and how many requests fit in it.
    duration = 10
    num_requests = 3

    def __init__(self):
        # History list of the current client; set by allow_request(), read by wait().
        self.history = None

    def allow_request(self, request, view):
        """Return True if the request may proceed, False if it is throttled."""
        remote_addr = request.META.get('REMOTE_ADDR')
        current_time = time.time()
        # A first-time visitor must get a *list*. Storing the bare timestamp
        # (as before) makes the next request fail on ``history[-1]``.
        history = VISIT_RECORD.setdefault(remote_addr, [])
        self.history = history
        # The oldest timestamps sit at the end; drop those outside the window.
        while history and history[-1] < current_time - self.duration:
            history.pop()
        if len(history) < self.num_requests:
            history.insert(0, current_time)
            return True
        # Return an explicit False instead of falling through to an implicit None.
        return False

    def wait(self):
        """Return the seconds until the oldest recorded visit leaves the window.

        Return None when there is no recorded history to base the estimate on.
        """
        if not self.history:
            return None
        current_time = time.time()
        return self.duration - (current_time - self.history[-1])
上述是自己定义的限流逻辑,实际上 drf 内部已经实现了限流功能,只需要导入就能使用:
SimpleRateThrottle 源码解读:简单速率限制
class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.
    The rate (requests / seconds) is set by a `rate` attribute on the View
    class. The attribute is a string of the form 'number_of_requests/period'.
    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    # Subclasses fill in scope and ident, e.g. 'throttle_anon_<ident>'.
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    # Mapping of scope name -> rate string, taken from DEFAULT_THROTTLE_RATES.
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # An explicit `rate` attribute wins; otherwise look the rate up by `scope`.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.
        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)
        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            # The scope is set but has no entry in DEFAULT_THROTTLE_RATES.
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the first character of the period is used, so '3/m', '3/min'
        # and '3/minute' are equivalent; any other letter raises KeyError.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.
        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        # A rate of None (e.g. the default 'user'/'anon' entries) disables throttling.
        if self.rate is None:
            return True
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
        # History is newest-first: throttle_success inserts at index 0, and
        # expired timestamps are popped from the end here.
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        # The third argument is the cache timeout: the entry lives for one window.
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        # Time left until the oldest recorded request leaves the window.
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration
        # Spread the remaining time over the requests still available
        # (+1 accounts for the request being made now).
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None
        return remaining_duration / float(available_requests)
自定义限流类:
class MySimpleRateThrottle(SimpleRateThrottle):
    """Throttle every client by IP address, using the rate configured for "xxoo"."""

    # Key into REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] that supplies the rate.
    scope = "xxoo"

    def get_cache_key(self, request, view):
        """Return the cache key for this client, namespaced by scope.

        A bare IP used as the key could collide with unrelated cache entries or
        with another throttle keyed on the same IP. Using ``cache_format``
        (as the built-in Anon/UserRateThrottle do) keeps the key unique.
        """
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request),
        }

    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES
        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            # Count back past the trusted proxies to reach the client address.
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()
        # With NUM_PROXIES unset, use the whole header (whitespace removed) as the ident.
        return ''.join(xff.split()) if xff else remote_addr
配置文件:
# settings.py: the rate used by the custom "xxoo" scope (3 requests per minute).
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_RATES=dict(xxoo="3/m"),
)
APIView 中需要加上 throttle_classes:
class UserInfo(APIView):
    """Demo view throttled by MySimpleRateThrottle."""

    throttle_classes = [MySimpleRateThrottle]
效果:
内置限流类:
匿名限流类:针对未登录(匿名)用户的限流控制类
class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.
    The IP address of the request will be used as the unique cache key.
    """
    # Rate is looked up under DEFAULT_THROTTLE_RATES['anon'].
    scope = 'anon'

    def get_cache_key(self, request, view):
        # Returning None skips throttling, so logged-in users are never limited here.
        if request.user.is_authenticated:
            return None
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }
认证用户限流类:针对登录(认证)用户的限流控制类
class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.
    The user id will be used as a unique cache key if the user is
    authenticated. For anonymous requests, the IP address of the request will
    be used.
    """
    # Rate is looked up under DEFAULT_THROTTLE_RATES['user'].
    scope = 'user'

    def get_cache_key(self, request, view):
        # Authenticated: key on the user's primary key; otherwise fall back to the client IP.
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }
范围限流类(ScopedRateThrottle):按视图上设置的 throttle_scope 分别限流,对登录用户按用户 id、对匿名用户按 IP 区分
class ScopedRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls by different amounts for various parts of
    the API. Any view that has the `throttle_scope` property set will be
    throttled. The unique cache key will be generated by concatenating the
    user id of the request, and the scope of the view being accessed.
    """
    # Name of the view attribute that selects the throttle scope.
    scope_attr = 'throttle_scope'

    def __init__(self):
        # Deliberately skips SimpleRateThrottle.__init__: the scope (and hence
        # the rate) is only known once the view is available in allow_request().
        pass

    def allow_request(self, request, view):
        # Views without a throttle_scope are never throttled by this class.
        self.scope = getattr(view, self.scope_attr, None)
        if not self.scope:
            return True
        # Resolve this view's rate, then run the generic sliding-window check.
        self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
        return super().allow_request(request, view)

    def get_cache_key(self, request, view):
        """
        If `view.throttle_scope` is not set, don't apply this throttle.
        Otherwise generate the unique cache key by concatenating the user id
        with the '.throttle_scope` property of the view.
        """
        # Anonymous requests fall back to the client IP as the identity.
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        # Views sharing a throttle_scope share one cache key, and therefore one budget.
        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }
配置文件:
# settings.py: enable both built-in throttles globally; anonymous clients get
# 3 requests/minute, authenticated users 5 requests/minute.
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_CLASSES=(
        'rest_framework.throttling.AnonRateThrottle',
        'rest_framework.throttling.UserRateThrottle',
    ),
    DEFAULT_THROTTLE_RATES=dict(user='5/minute', anon='3/minute'),
)
针对范围的配置文件稍微不一样:
# settings.py: scope-based throttling; each view picks a scope via throttle_scope.
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_CLASSES=('rest_framework.throttling.ScopedRateThrottle',),
    DEFAULT_THROTTLE_RATES=dict(list='3/m', get='5/m'),
)
视图中:
class ListView(APIView):
    """Throttled under the 'list' scope, whose budget it shares with DetailView."""

    throttle_scope = 'list'
    ...  # rest of the view omitted
class DetailView(APIView):
    """Throttled under the 'list' scope, whose budget it shares with ListView."""

    throttle_scope = 'list'
    ...  # rest of the view omitted
class GetView(APIView):
    """Throttled under its own 'get' scope."""

    throttle_scope = 'get'
    ...  # rest of the view omitted
版本
系统都有版本,版本是迭代过程中的一种标记,用来记录演进过程。常见的后端一般会在 url 中加上版本,如:/api/v1/users,这种做法在 restful api 中被提出,个人理解是和系统迭代挂钩。drf 中也有版本的概念,估计是为了兼容 restful api 这种理念;实际上个人感觉不太需要,业务开发时一般都会自己定义,不过还是看下 drf 中版本的玩法:drf 中主要有两种方式,一种是通过查询参数获取,另一种是配置在 url 路径中。
QueryParameterVersioning:参数传递
class QueryParameterVersioning(BaseVersioning):
    """
    GET /something/?version=0.1 HTTP/1.1
    Host: example.com
    Accept: application/json

    The version is read from the query string (not the URL path); the
    parameter name is self.version_param, presumably configured via the
    VERSION_PARAM setting.
    """
    invalid_version_message = _('Invalid version in query parameter.')

    def determine_version(self, request, *args, **kwargs):
        # Fall back to the default version when the parameter is absent.
        version = request.query_params.get(self.version_param, self.default_version)
        if not self.is_allowed_version(version):
            raise exceptions.NotFound(self.invalid_version_message)
        return version
配置文件:
# settings.py: versioning defaults; "version" is the query/URL parameter name.
REST_FRAMEWORK = dict(
    DEFAULT_VERSION="v1",
    ALLOWED_VERSIONS=["v1", "v2"],
    VERSION_PARAM="version",
)
视图中:
class UserInfo(APIView):
    """Demo view showing where QueryParameterVersioning finds the version."""

    versioning_class = QueryParameterVersioning

    def get(self, request, *args, **kwargs):
        """Print the raw query value and the resolved version, then return an empty response."""
        # Raw ?version=... value. DRF's public request.query_params wraps the
        # same QueryDict as the private request._request.GET used previously.
        version = request.query_params.get("version")
        print(version)
        # Resolved by the versioning class; falls back to its default_version
        # when the query parameter is missing.
        print(request.version)
        return Response()
效果:
URLPathVersioning:路径中(推荐使用)
class URLPathVersioning(BaseVersioning):
    """
    To the client this is the same style as `NamespaceVersioning`.
    The difference is in the backend - this implementation uses
    Django's URL keyword arguments to determine the version.
    An example URL conf for two views that accept two different versions.
    urlpatterns = [
        url(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'),
        url(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
    ]
    GET /1.0/something/ HTTP/1.1
    Host: example.com
    Accept: application/json
    """
    # Only the docstring is reproduced here; the implementation is omitted.
    # NOTE: `[v1|v2]+` in the example above is a regex character class: it
    # matches any run of the characters 'v', '1', '|', '2' (e.g. "2v"), not
    # "v1 or v2". A group such as `(?P<version>v1|v2)` captures only those two.
项目 urls.py 中:
from django.contrib import admin
from django.conf.urls import url, include
# Project-level URL configuration.
urlpatterns = [
    # url() takes a regular expression that Django matches with search(), so
    # anchor it with ^ or the prefix would also match in the middle of a path.
    url(r'^admin/', admin.site.urls),
    # Django ignores a `name` passed together with include(); name the
    # individual routes inside djcelerytest.urls instead.
    url(r'^api/', include("djcelerytest.urls")),
]
app urls.py 中:
from django.conf.urls import url
from djcelerytest.views import test, UserInfo
# App-level URL configuration (mounted under api/).
urlpatterns = [
    url(r'^test/$', test),
    # `[v1|v2]+` was a character class (any run of 'v', '1', '|', '2', e.g. "2v");
    # a group captures exactly "v1" or "v2" for URLPathVersioning.
    url(r'^(?P<version>v1|v2)/info/$', UserInfo.as_view()),
]
视图中:
class UserInfo(APIView):
    """Demo view that reads the version captured from the URL path."""

    versioning_class = URLPathVersioning

    def get(self, request, *args, **kwargs):
        # URLPathVersioning takes the `version` URL kwarg and stores it on request.version.
        current_version = request.version
        print(current_version)
        return Response()
版本源码解读
initial 中在执行认证、权限之前会先获取版本号。 determine_version:
def determine_version(self, request, *args, **kwargs):
    """
    If versioning is being used, then determine any API version for the
    incoming request. Returns a two-tuple of (version, versioning_scheme)
    """
    # No versioning class configured (the default): no version, no scheme.
    if self.versioning_class is None:
        return (None, None)
    # A fresh scheme instance per request; it is returned so it can later be
    # attached to the request.
    scheme = self.versioning_class()
    return (scheme.determine_version(request, *args, **kwargs), scheme)
determine_version :
def determine_version(self, request, *args, **kwargs):
    # URLPathVersioning: the version comes from the URL keyword arguments.
    version = kwargs.get(self.version_param, self.default_version)
    # The kwarg may be present but None; treat that like "missing".
    if version is None:
        version = self.default_version
    if not self.is_allowed_version(version):
        raise exceptions.NotFound(self.invalid_version_message)
    return version
赋值给 request:
# Excerpt from APIView.initial(): resolve the version and attach both the
# version value and the versioning-scheme instance to the request.
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
versioning_scheme:可用于反向解析 URL。以 QueryParameterVersioning 为例,urls.py:
# QueryParameterVersioning reads the version from ?version=..., so this route
# needs no version group. With a required `version` group, reverse("xxoo")
# (called without kwargs) raises NoReverseMatch. QueryParameterVersioning's
# reverse() appends the ?version=... parameter itself.
urlpatterns = [
    url(r'^info/$', UserInfo.as_view(), name="xxoo"),
]
class UserInfo(APIView):
    """Demo view: build this route's URL through the request's versioning scheme."""

    versioning_class = QueryParameterVersioning

    def get(self, request, *args, **kwargs):
        scheme = request.versioning_scheme
        print(request.version)
        # Reverse the named route "xxoo" via the versioning scheme.
        versioned_url = scheme.reverse(viewname="xxoo", request=request)
        print(versioned_url)
        return Response()
其它版本类:
class NamespaceVersioning(BaseVersioning):
    """
    To the client this is the same style as `URLPathVersioning`.
    The difference is in the backend - this implementation uses
    Django's URL namespaces to determine the version.
    An example URL conf that is namespaced into two separate versions
    # users/urls.py
    urlpatterns = [
        url(r'^/users/$', users_list, name='users-list'),
        url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
    ]
    """
    # Only the docstring is reproduced here; the implementation is omitted.
class HostNameVersioning(BaseVersioning):
    """
    GET /something/ HTTP/1.1
    Host: v1.example.com
    Accept: application/json
    """
    # Per the example above, the version is carried in the Host name (v1.example.com).
    # Only the docstring is reproduced here; the implementation is omitted.
class AcceptHeaderVersioning(BaseVersioning):
    """
    GET /something/ HTTP/1.1
    Host: example.com
    Accept: application/json; version=1.0
    """
    # Per the example above, the version is a parameter of the Accept header.
    # Only the docstring is reproduced here; the implementation is omitted.
解析器
我们在 request.data 中拿到的数据,实际上是 drf 根据请求头 Content-Type 选择对应的解析器(parser)解析出来的;而 request.query_params 只是 Django 原生的 request.GET,并不经过解析器。先看 Django 原生 HttpRequest 是如何处理请求体的,源码如下:
# Excerpt from Django's HttpRequest: fill POST/FILES according to the
# Content-Type. Only form encodings are parsed here.
if self.content_type == 'multipart/form-data':
    if hasattr(self, '_body'):
        # Body already read: parse from an in-memory copy of it.
        data = BytesIO(self._body)
    else:
        # Otherwise parse by reading from the request object itself.
        data = self
    try:
        self._post, self._files = self.parse_file_upload(self.META, data)
    except MultiPartParserError:
        # Presumably flags POST as unparsable before re-raising.
        self._mark_post_parse_error()
        raise
elif self.content_type == 'application/x-www-form-urlencoded':
    self._post, self._files = QueryDict(self.body, encoding=self._encoding), MultiValueDict()
else:
    # Any other Content-Type (e.g. application/json) leaves POST and FILES empty.
    self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict()
DRF 中的解析器:
JSONParser: 只支持Content-Type = application/json
class JSONParser(BaseParser):
    """
    Parses JSON-serialized data.
    """
    # Content-Type this parser handles.
    media_type = 'application/json'
    renderer_class = renderers.JSONRenderer
    # Taken from the STRICT_JSON setting.
    strict = api_settings.STRICT_JSON
如果是其他类型会提示:
{
"detail": "不支持请求中的媒体类型 “application/x-www-form-urlencoded”。"
}
FormParser: 只支持Content-Type = ‘application/x-www-form-urlencoded’
class FormParser(BaseParser):
    """
    Parser for form data.
    """
    # Content-Type this parser handles (URL-encoded HTML forms).
    media_type = 'application/x-www-form-urlencoded'
MultiPartParser: 只支持Content-Type = ‘multipart/form-data’ 文件上传
class MultiPartParser(BaseParser):
    """
    Parser for multipart form data, which may include file data.
    """
    # Content-Type this parser handles (multipart forms, e.g. file uploads).
    media_type = 'multipart/form-data'
**FileUploadParser:** media_type = `*/*`,什么类型都可以
class FileUploadParser(BaseParser):
    """
    Parser for file upload data.
    """
    # Wildcard media type: matches any Content-Type.
    media_type = '*/*'
    # Error messages keyed by failure kind.
    errors = {
        'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
        'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
    }
视图中需要指定解析器:
class UserInfo(APIView):
    """Demo view whose request body is parsed only as JSON."""

    parser_classes = [JSONParser]
也可以全局配置:
# settings.py: make JSON the only request parser for every view.
REST_FRAMEWORK = dict(
    DEFAULT_PARSER_CLASSES=["rest_framework.parsers.JSONParser"],
)
dispatch 在封装 request 时也会把 parser_classes 封装进去:
def initialize_request(self, request, *args, **kwargs):
    """
    Returns the initial request object.
    """
    parser_context = self.get_parser_context(request)
    # Wrap Django's request in DRF's Request, handing it the view's parsers,
    # authenticators and content negotiator; request.data is later parsed with these.
    return Request(
        request,
        parsers=self.get_parsers(),
        authenticators=self.get_authenticators(),
        negotiator=self.get_content_negotiator(),
        parser_context=parser_context
    )
Django REST Framework 全局配置文件
疑问?那么多配置,如果自己没有学到,或者压根不知道,那么默认是从哪里来的?是不是都要配置?其实大可不必,DRF 默认也会给我们配置一些。从下面导入:
from rest_framework.settings import api_settings
默认配置:
# DRF's built-in defaults, used for any key not overridden in
# settings.REST_FRAMEWORK.
DEFAULTS = {
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
    # JSON, urlencoded forms and multipart forms are all parsed out of the box.
    'DEFAULT_PARSER_CLASSES': [
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.FormParser',
        'rest_framework.parsers.MultiPartParser'
    ],
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication'
    ],
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.AllowAny',
    ],
    # Empty: presumably no throttling unless configured.
    'DEFAULT_THROTTLE_CLASSES': [],
    'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
    'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
    # None: determine_version() returns (None, None).
    'DEFAULT_VERSIONING_CLASS': None,
    'DEFAULT_PAGINATION_CLASS': None,
    'DEFAULT_FILTER_BACKENDS': [],
    'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
    # A None rate makes SimpleRateThrottle.allow_request() let every request through.
    'DEFAULT_THROTTLE_RATES': {
        'user': None,
        'anon': None,
    },
    # None: get_ident() uses the whole X-Forwarded-For header, else REMOTE_ADDR.
    'NUM_PROXIES': None,
    'PAGE_SIZE': None,
    'SEARCH_PARAM': 'search',
    'ORDERING_PARAM': 'ordering',
    'DEFAULT_VERSION': None,
    'ALLOWED_VERSIONS': None,
    # Name of the query parameter / URL kwarg that carries the version.
    'VERSION_PARAM': 'version',
    'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
    'UNAUTHENTICATED_TOKEN': None,
    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
    'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
    'NON_FIELD_ERRORS_KEY': 'non_field_errors',
    'TEST_REQUEST_RENDERER_CLASSES': [
        'rest_framework.renderers.MultiPartRenderer',
        'rest_framework.renderers.JSONRenderer'
    ],
    'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
    'URL_FORMAT_OVERRIDE': 'format',
    'FORMAT_SUFFIX_KWARG': 'format',
    'URL_FIELD_NAME': 'url',
    'DATE_FORMAT': ISO_8601,
    'DATE_INPUT_FORMATS': [ISO_8601],
    'DATETIME_FORMAT': ISO_8601,
    'DATETIME_INPUT_FORMATS': [ISO_8601],
    'TIME_FORMAT': ISO_8601,
    'TIME_INPUT_FORMATS': [ISO_8601],
    'UNICODE_JSON': True,
    'COMPACT_JSON': True,
    # Read by JSONParser.strict.
    'STRICT_JSON': True,
    'COERCE_DECIMAL_TO_STRING': True,
    'UPLOADED_FILES_USE_URL': True,
    'HTML_SELECT_CUTOFF': 1000,
    'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
    'SCHEMA_COERCE_PATH_PK': True,
    'SCHEMA_COERCE_METHOD_NAMES': {
        'retrieve': 'read',
        'destroy': 'delete'
    },
}
总结
限流:限流工作中一般用的很少,针对大并发一般是想办法优化,不会截流 版本:drf 中版本没必要用,url 中自己写比使用内置的可读性高很多了 解析器:使用默认的就行,一般工作中也不会刻意配置
参考文献
drf 官网:https://www.django-rest-framework.org/api-guide/settings/ drf 中文文档:https://q1mi.github.io/Django-REST-framework-documentation/
|