drf————源码分析>
- 认证源码分析
- 权限源码分析
- 频率类源码分析
三大认证的源码分析
之前读取的APIView的源码的执行流程中包装了新的request,执行了三大认证,执行视图类的方法,处理了全局异常
-
查看源码的入口
APIView的dispatch
-
进入后在APIView的dispatch的496行上下
self.initial(request, *args, **kwargs)中
-
查看APIView的initial
413行上下有三句话,分别是认证、权限、频率
self.perform_authentication(request) self.check_permissions(request) self.check_throttles(request)
这三个分别是三大认证的源码分析的读取入口
认证源码分析
-
读取认证类源码——APIView的perform_authentication(request)
def perform_authentication(self, request): request.user # 新的request
request是新的request
Request类中找user属性(方法),是个方法包装成了数据属性
-
点击request类找到user属性(方法)
def user(self): if not hasattr(self, '_user'): # Request类的对象中反射_user with wrap_attributeerrors(): self._authenticate() # 第一次会走这个代码 return self._user
查找到user属性后先进行判断函数是否包含_user属性,不包含则进行if内操作调用 self._authenticate()
-
点击 Request的self._authenticate()
def _authenticate(self): for authenticator in self.authenticators: # 配置在视图类中所有的认证类的对象 try: #(user_token.user, token) user_auth_tuple = authenticator.authenticate(self) # 调用认证类对象的authenticate except exceptions.APIException: self._not_authenticated() raise if user_auth_tuple is not None: self._authenticator = authenticator # 忽略 self.user, self.auth = user_auth_tuple # 解压赋值 return self._not_authenticated() # 认证类可以配置多个,但是如果有一个返回了两个值,后续的就不执行了
-
总结
认证类,要重写authenticate方法,认证通过返回两个值或None,
认证不通过抛AuthenticationFailed(继承了APIException)异常
权限源码分析
-
读取权限类源码 APIView的check_permissions(request)
def check_permissions(self, request): for permission in self.get_permissions(): # permission是咱们配置在视图类中权限类的对象,对象调用它的绑定方法has_permission # 对象调用自己的绑定方法会把自己传入(权限类的对象,request,视图类的对象) if not permission.has_permission(request, self): self.permission_denied( request, message=getattr(permission, 'message', None), code=getattr(permission, 'code', None) )
-
读取APIVIew的 self.get_permissions()
return [permission() for permission in self.permission_classes] """self.permission_classes 就是咱们在视图类中配的权限类的列表"""
所以这个get_permissions返回的是 在视图类中配的权限类的对象列表[UserTypePermession(),]
-
总结
为什么要写一个类,重写has_permission方法,
has_permission有三个参数,分别是 权限类的对象,request,视图类的对象
为什么一定要return True或False:
返回True是通过权限,Flase是未通过权限
messgage的作用:
调用self.messgage可以更改提示信息
频率类源码
-
读取频率类源码 APIView的check_throttles
def check_throttles(self, request): throttle_durations = [] for throttle in self.get_throttles(): if not throttle.allow_request(request, self): throttle_durations.append(throttle.wait()) def get_throttles(self): return [throttle() for throttle in self.throttle_classes]
要写频率类,必须重写allow_request方法,然后结束for循环返回值
-
读取 allow_request 源码
源码里执行的频率类的allow_request,读SimpleRateThrottle的allow_request
class SimpleRateThrottle(BaseThrottle): cache = default_cache timer = time.time cache_format = 'throttle_%(scope)s_%(ident)s' scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): # 只要类实例化得到对象就会执行,一执行,self.rate就有值了,而且self.num_requests和self.duration if not getattr(self, 'rate', None): # 去频率类中反射rate属性或方法,发现没有,返回了None,这个if判断就符合,执行下面的代码 self.rate = self.get_rate() #返回了 '3/m' # self.num_requests=3 # self.duration=60 self.num_requests, self.duration = self.parse_rate(self.rate) def get_rate(self): return self.THROTTLE_RATES[self.scope] # 字典取值,配置文件中咱们配置的字典{'ss': '3/m',},根据ss取到了 '3/m' def parse_rate(self, rate): if rate is None: return (None, None) # rate:字符串'3/m' 根据 / 切分,切成了 ['3','m'] # num=3,period=m num, period = rate.split('/') # num_requests=3 数字3 num_requests = int(num) # period='m' ---->period[0]--->'m' # {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] # duration=60 duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] # 3 60 return (num_requests, duration) def allow_request(self, request, view): if self.rate is None: return True # 咱们自己写的,返回什么就以什么做限制 咱们返回的是ip地址 # self.key=当前访问者的ip地址 self.key = self.get_cache_key(request, view) if self.key is None: return True # self.history 访问者的时间列表,从缓存中拿到,如果拿不到就是空列表,如果之前有 [时间2,时间1] self.history = self.cache.get(self.key, []) # 当前时间 self.now = self.timer() while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success()
-
总结
要写频率类,必须重写allow_request方法
返回True(没有到频率的限制)或False(到了频率的限制)
以后要再写频率类,只需要继承SimpleRateThrottle,重写get_cache_key,配置类属性scope,配置文件中配置一下就可以了
排序和过滤源码分析
继承了GenericAPIView+ListModelMixin,只要在视图类中配置filter_backends它就能实现过滤和排序
-
drf内置的过滤类(SearchFilter),排序类(OrderingFiler)
from rest_framework.filters import SearchFilter,OrderingFiler,BaseFilterBackend
在类中只需继承了GenericAPIView+ListModelMixin,配置filter_backends即可直接使用模块的过滤和排序
-
排序和过滤源码剖析
排序和过滤只有涉及到查看多个(list)才起作用
故需继承ListAPIView(继承了GenericAPIView+ListModelMixin)
ListModelMixin: def list(self, request, *args, **kwargs): # self.get_queryset()所有数据,经过了self.filter_queryset返回了qs # self.filter_queryset完成的过滤 queryset = self.filter_queryset(self.get_queryset()) # 如果有分页,走的分页----》视图类中配置了分页类 page = self.paginate_queryset(queryset) if page is not None: serializer = self.get_serializer(page, many=True) return self.get_paginated_response(serializer.data) # 如果没有分页,走正常的序列化,返回 serializer = self.get_serializer(queryset, many=True) return Response(serializer.data)
self.filter_queryset完成了过滤,当前在视图类中,self是视图类的对象
视图类中没找到去其父类找 ,找到 GenericAPIView 下的 filter_queryset
def filter_queryset(self, queryset): for backend in list(self.filter_backends): queryset = backend().filter_queryset(self.request, queryset, self) return queryset
重写 filter_queryset
drf内置的过滤类(SearchFilter),排序类(OrderingFiler)
即重写过滤类(SearchFilter),排序类(OrderingFiler)内部的filter_queryset方法即可自定义
排序
from rest_framework.filters import SearchFilter,OrderingFiler,BaseFilterBackend class 类名(BaseFilterBackend): def filter_queryset(self, request, queryset, view): # 重写方法 return queryset
-
总结:
写的过滤类要重写filter_queryset,返回qs(过滤或排序后)对象
后期如果不写过滤类,只要在视图类中重写filter_queryset,在里面实现过滤也可以
restframework-jwt执行流程
restframework-jwt 就是签发流程
本质就是登录接口,为了校验用户是否正确
如果正确签发token,写到了序列化类中,如果不正确返回错误
读取源码的入口:
obtain_jwt_token:核心代码--ObtainJSONWebToken.as_view()
-
ObtainJSONWebToken
视图类,实现了登录功能
class ObtainJSONWebToken(JSONWebTokenAPIView): serializer_class = JSONWebTokenSerializer
找其父类
class JSONWebTokenAPIView(APIView): # 局部禁用掉权限和认证 permission_classes = () authentication_classes = () def get_serializer_context(self): return { 'request': self.request, 'view': self, } def get_serializer_class(self): return self.serializer_class def get_serializer(self, *args, **kwargs): serializer_class = self.get_serializer_class() kwargs['context'] = self.get_serializer_context() return serializer_class(*args, **kwargs) def post(self, request, *args, **kwargs): # JSONWebTokenSerializer实例化得到一个序列号类的对象,传入前端传的只 serializer = self.get_serializer(data=request.data) if serializer.is_valid(): # 校验前端传入的数据是否合法: #1 字段自己的规则 2 局部钩子 3 全局钩子(序列化类的validate方法) # 获取当前登录用户和签发token是在序列化类中完成的 # 从序列化类对象中取出了当前登录用户 user = serializer.object.get('user') or request.user # # 从序列化类对象中取出了token token = serializer.object.get('token') # 自定义过 response_data = jwt_response_payload_handler(token, user, request) response = Response(response_data) return response return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
-
序列化类 JSONWebTokenSerializer
class JSONWebTokenSerializer(Serializer): def validate(self, attrs): credentials = { 'username': attrs.get('username'), 'password': attrs.get('password') } if all(credentials.values()): # auth的校验用户名和密码是否正确 user = authenticate(**credentials) if user: # 通过用户获得payload:{} payload = jwt_payload_handler(user) return { 'token': jwt_encode_handler(payload), 'user': user } else: # 根据用户名和密码查不到用户 raise serializers.ValidationError(msg) else: # 用户名和密码不传,传多了都不行 raise serializers.ValidationError(msg)
-
认证类 JSONWebTokenAuthentication
class JSONWebTokenAuthentication(BaseJSONWebTokenAuthentication): def get_jwt_value(self, request): # get_authorization_header(request)根据请求头中HTTP_AUTHORIZATION,取出token # jwt adsfasdfasdfad # auth=['jwt','真正的token'] auth = get_authorization_header(request).split() auth_header_prefix = api_settings.JWT_AUTH_HEADER_PREFIX.lower() if not auth: if api_settings.JWT_AUTH_COOKIE: return request.COOKIES.get(api_settings.JWT_AUTH_COOKIE) return None if smart_text(auth[0].lower()) != auth_header_prefix: return None if len(auth) == 1: msg = _('Invalid Authorization header. No credentials provided.') raise exceptions.AuthenticationFailed(msg) elif len(auth) > 2: msg = _('Invalid Authorization header. Credentials string ' 'should not contain spaces.') raise exceptions.AuthenticationFailed(msg) return auth[1]
其父类 BaseJSONWebTokenAuthentication---》authenticate
class BaseJSONWebTokenAuthentication(BaseAuthentication): def authenticate(self, request): # jwt_value前端传入的token jwt_value = self.get_jwt_value(request) # 前端没有传入token,return None,没有带token,认证类也能过,所有咱们才加权限类 if jwt_value is None: return None try: payload = jwt_decode_handler(jwt_value) # 验证token,token合法,返回payload except jwt.ExpiredSignature: msg = _('Signature has expired.') raise exceptions.AuthenticationFailed(msg) except jwt.DecodeError: msg = _('Error decoding signature.') raise exceptions.AuthenticationFailed(msg) except jwt.InvalidTokenError: raise exceptions.AuthenticationFailed() user = self.authenticate_credentials(payload) # 通过payload得到当前登录用户 return (user, jwt_value) # 后期的request.user就是当前登录用户
它这个认证类:只要带了token,request.user就有只,如果没带token,不管了,继续往后走