首页 > 编程语言 >drf————源码分析

drf————源码分析

时间:2023-09-12 20:58:21浏览次数:37  
标签:分析 return get queryset self request 源码 user drf

drf————源码分析>

  • 认证源码分析
  • 权限源码分析
  • 频率类源码分析

三大认证的源码分析

之前读取的APIView的源码的执行流程中包装了新的request,执行了三大认证,执行视图类的方法,处理了全局异常

  • 查看源码的入口

    APIView的dispatch

  • 进入后在APIView的dispatch的496行上下

    self.initial(request, *args, **kwargs)中

  • 查看APIView的initial

    413行上下有三句话,分别是认证、权限、频率

     self.perform_authentication(request)
     self.check_permissions(request)
     self.check_throttles(request)
    

    这三个分别是三大认证的源码分析的读取入口

认证源码分析

  • 读取认证类源码——APIView的perform_authentication(request)

     def perform_authentication(self, request):
         request.user  # 新的request
    

    request是新的request

    Request类中找user属性(方法),是个方法包装成了数据属性

  • 点击request类找到user属性(方法)

     def user(self):
         if not hasattr(self, '_user'): # Request类的对象中反射_user
             with wrap_attributeerrors():
                 self._authenticate()  # 第一次会走这个代码
        	return self._user
    

    查找到user属性后先进行判断函数是否包含_user属性,不包含则进行if内操作调用 self._authenticate()

  • 点击 Request的self._authenticate()

     def _authenticate(self):
         for authenticator in self.authenticators: # 配置在视图类中所有的认证类的对象 
             try:
                 #(user_token.user, token)
                 user_auth_tuple = authenticator.authenticate(self) 
                 # 调用认证类对象的authenticate
             except exceptions.APIException:
                 self._not_authenticated()
                 raise
    
             if user_auth_tuple is not None:
                 self._authenticator = authenticator # 忽略
                 self.user, self.auth = user_auth_tuple # 解压赋值
                 return  self._not_authenticated()
                     # 认证类可以配置多个,但是如果有一个返回了两个值,后续的就不执行了
    
  • 总结

    认证类,要重写authenticate方法,认证通过返回两个值或None,

    认证不通过抛AuthenticationFailed(继承了APIException)异常

权限源码分析

  • 读取权限类源码 APIView的check_permissions(request)

     def check_permissions(self, request):
         for permission in self.get_permissions():
             # permission是咱们配置在视图类中权限类的对象,对象调用它的绑定方法has_permission
             # 对象调用自己的绑定方法会把自己传入(权限类的对象,request,视图类的对象)
             if not permission.has_permission(request, self):
                 self.permission_denied(
                     request,
                     message=getattr(permission, 'message', None),
                     code=getattr(permission, 'code', None)
                 )
    
  • 读取APIVIew的 self.get_permissions()

         return [permission() for permission in self.permission_classes]
         """self.permission_classes 就是咱们在视图类中配的权限类的列表"""
    

    所以这个get_permissions返回的是 在视图类中配的权限类的对象列表[UserTypePermession(),]

  • 总结

    为什么要写一个类,重写has_permission方法,

    has_permission有三个参数,分别是 权限类的对象,request,视图类的对象

    为什么一定要return True或False:

    返回True是通过权限,Flase是未通过权限

    messgage的作用:

    调用self.messgage可以更改提示信息

频率类源码

  • 读取频率类源码 APIView的check_throttles

     def check_throttles(self, request):
         throttle_durations = []
         for throttle in self.get_throttles():
             if not throttle.allow_request(request, self):
                 throttle_durations.append(throttle.wait())
        
     def get_throttles(self):
         return [throttle() for throttle in self.throttle_classes]
    

    要写频率类,必须重写allow_request方法,然后结束for循环返回值

  • 读取 allow_request 源码

    源码里执行的频率类的allow_request,读SimpleRateThrottle的allow_request

     class SimpleRateThrottle(BaseThrottle):
         cache = default_cache
         timer = time.time
         cache_format = 'throttle_%(scope)s_%(ident)s'
         scope = None
         THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
         def __init__(self):  # 只要类实例化得到对象就会执行,一执行,self.rate就有值了,而且self.num_requests和self.duration
             if not getattr(self, 'rate', None): # 去频率类中反射rate属性或方法,发现没有,返回了None,这个if判断就符合,执行下面的代码
                 self.rate = self.get_rate()  #返回了  '3/m'
             #  self.num_requests=3
             #  self.duration=60
             self.num_requests, self.duration = self.parse_rate(self.rate)
    
         def get_rate(self):
              return self.THROTTLE_RATES[self.scope] # 字典取值,配置文件中咱们配置的字典{'ss': '3/m',},根据ss取到了 '3/m'
    
         def parse_rate(self, rate):
             if rate is None:
                 return (None, None)
             # rate:字符串'3/m'  根据 / 切分,切成了 ['3','m']
             # num=3,period=m
             num, period = rate.split('/')
             # num_requests=3  数字3
             num_requests = int(num)
             # period='m'  ---->period[0]--->'m'
             # {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
             # duration=60
             duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
             # 3     60
             return (num_requests, duration)
    
         def allow_request(self, request, view):
             if self.rate is None:
                 return True
             # 咱们自己写的,返回什么就以什么做限制  咱们返回的是ip地址
             # self.key=当前访问者的ip地址
             self.key = self.get_cache_key(request, view)
             if self.key is None:
                 return True
             # self.history 访问者的时间列表,从缓存中拿到,如果拿不到就是空列表,如果之前有 [时间2,时间1]
             self.history = self.cache.get(self.key, [])
             # 当前时间
             self.now = self.timer()
             while self.history and self.history[-1] <= self.now - self.duration:
                 self.history.pop()
             if len(self.history) >= self.num_requests:
                 return self.throttle_failure()
             return self.throttle_success()
    
  • 总结

    要写频率类,必须重写allow_request方法

    返回True(没有到频率的限制)或False(到了频率的限制)

    以后要再写频率类,只需要继承SimpleRateThrottle,重写get_cache_key,配置类属性scope,配置文件中配置一下就可以了

排序和过滤源码分析

继承了GenericAPIView+ListModelMixin,只要在视图类中配置filter_backends它就能实现过滤和排序

  • drf内置的过滤类(SearchFilter),排序类(OrderingFiler)

     from rest_framework.filters import SearchFilter,OrderingFiler,BaseFilterBackend
    

    在类中只需继承了GenericAPIView+ListModelMixin,配置filter_backends即可直接使用模块的过滤和排序

  • 排序和过滤源码剖析

    排序和过滤只有涉及到查看多个(list)才起作用

    故需继承ListAPIView(继承了GenericAPIView+ListModelMixin)

     ListModelMixin:
       def list(self, request, *args, **kwargs):
           # self.get_queryset()所有数据,经过了self.filter_queryset返回了qs
            # self.filter_queryset完成的过滤
           queryset = self.filter_queryset(self.get_queryset())
              # 如果有分页,走的分页----》视图类中配置了分页类
           page = self.paginate_queryset(queryset)
           if page is not None:
               serializer = self.get_serializer(page, many=True)
               return self.get_paginated_response(serializer.data)
            # 如果没有分页,走正常的序列化,返回
           serializer = self.get_serializer(queryset, many=True)
           return Response(serializer.data)
    

    self.filter_queryset完成了过滤,当前在视图类中,self是视图类的对象

    视图类中没找到去其父类找 ,找到 GenericAPIView 下的 filter_queryset

     def filter_queryset(self, queryset):
         for backend in list(self.filter_backends):
             queryset = backend().filter_queryset(self.request, queryset, self)
             return queryset
    

    重写 filter_queryset

    drf内置的过滤类(SearchFilter),排序类(OrderingFiler)

    即重写过滤类(SearchFilter),排序类(OrderingFiler)内部的filter_queryset方法即可自定义

    排序

     from rest_framework.filters import SearchFilter,OrderingFiler,BaseFilterBackend
    
     class 类名(BaseFilterBackend):
         def filter_queryset(self, request, queryset, view):
            # 重写方法
             return queryset
    
  • 总结:

    写的过滤类要重写filter_queryset,返回qs(过滤或排序后)对象

    后期如果不写过滤类,只要在视图类中重写filter_queryset,在里面实现过滤也可以

restframework-jwt执行流程

restframework-jwt 就是签发流程

本质就是登录接口,为了校验用户是否正确

如果正确签发token,写到了序列化类中,如果不正确返回错误

读取源码的入口:

obtain_jwt_token:核心代码--ObtainJSONWebToken.as_view()

  • ObtainJSONWebToken

    视图类,实现了登录功能

     class ObtainJSONWebToken(JSONWebTokenAPIView):
         serializer_class = JSONWebTokenSerializer
    

    找其父类

     class JSONWebTokenAPIView(APIView):
         # 局部禁用掉权限和认证
         permission_classes = () 
         authentication_classes = ()
    
         def get_serializer_context(self):
             return {
                 'request': self.request,
                 'view': self,
             }
    
         def get_serializer_class(self):
             return self.serializer_class
    
         def get_serializer(self, *args, **kwargs):
             serializer_class = self.get_serializer_class()
             kwargs['context'] = self.get_serializer_context()
             return serializer_class(*args, **kwargs)
    
         def post(self, request, *args, **kwargs):
             # JSONWebTokenSerializer实例化得到一个序列号类的对象,传入前端传的只
             serializer = self.get_serializer(data=request.data)
    
             if serializer.is_valid(): # 校验前端传入的数据是否合法:
                 #1 字段自己的规则 2 局部钩子 3 全局钩子(序列化类的validate方法)
                 # 获取当前登录用户和签发token是在序列化类中完成的
                 # 从序列化类对象中取出了当前登录用户
                 user = serializer.object.get('user') or request.user
                 # # 从序列化类对象中取出了token
                 token = serializer.object.get('token')
                 # 自定义过
                 response_data = jwt_response_payload_handler(token, user, request)
                 response = Response(response_data)
                 return response
    
             return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
    
  • 序列化类 JSONWebTokenSerializer

     class JSONWebTokenSerializer(Serializer):
         def validate(self, attrs):
             credentials = {
                 'username': attrs.get('username'),
                 'password': attrs.get('password')
             }
     
             if all(credentials.values()):
                 # auth的校验用户名和密码是否正确
                 user = authenticate(**credentials)
    
                 if user:
                     # 通过用户获得payload:{}
                     payload = jwt_payload_handler(user)
                     return {
                         'token': jwt_encode_handler(payload),
                         'user': user
                     }
                 else:
                     # 根据用户名和密码查不到用户
                     raise serializers.ValidationError(msg)
                     else:	
                         # 用户名和密码不传,传多了都不行
                         raise serializers.ValidationError(msg)
    
  • 认证类 JSONWebTokenAuthentication

     class JSONWebTokenAuthentication(BaseJSONWebTokenAuthentication):
         def get_jwt_value(self, request):
             # get_authorization_header(request)根据请求头中HTTP_AUTHORIZATION,取出token
             # jwt adsfasdfasdfad
             # auth=['jwt','真正的token']
             auth = get_authorization_header(request).split()
             auth_header_prefix = api_settings.JWT_AUTH_HEADER_PREFIX.lower()
             if not auth:
                 if api_settings.JWT_AUTH_COOKIE:
                     return request.COOKIES.get(api_settings.JWT_AUTH_COOKIE)
                 return None
             if smart_text(auth[0].lower()) != auth_header_prefix:
                 return None
             if len(auth) == 1:
                 msg = _('Invalid Authorization header. No credentials provided.')
                 raise exceptions.AuthenticationFailed(msg)
                 elif len(auth) > 2:
                     msg = _('Invalid Authorization header. Credentials string '
                             'should not contain spaces.')
                     raise exceptions.AuthenticationFailed(msg)
                     return auth[1]
    

    其父类 BaseJSONWebTokenAuthentication---》authenticate

     class BaseJSONWebTokenAuthentication(BaseAuthentication):
         def authenticate(self, request):
             # jwt_value前端传入的token
             jwt_value = self.get_jwt_value(request)
             # 前端没有传入token,return None,没有带token,认证类也能过,所有咱们才加权限类
             if jwt_value is None:
                 return None
             try:
                 payload = jwt_decode_handler(jwt_value) # 验证token,token合法,返回payload
                 except jwt.ExpiredSignature:
                     msg = _('Signature has expired.')
                     raise exceptions.AuthenticationFailed(msg)
                     except jwt.DecodeError:
                         msg = _('Error decoding signature.')
                         raise exceptions.AuthenticationFailed(msg)
                         except jwt.InvalidTokenError:
                             raise exceptions.AuthenticationFailed()
    
                             user = self.authenticate_credentials(payload) # 通过payload得到当前登录用户
    
                             return (user, jwt_value) # 后期的request.user就是当前登录用户
    

    它这个认证类:只要带了token,request.user就有只,如果没带token,不管了,继续往后走

标签:分析,return,get,queryset,self,request,源码,user,drf
From: https://www.cnblogs.com/ioubbu/p/17697765.html

相关文章