
Django: customizing Swagger's auto-schema generation class

Date: 2024-01-16 16:25:41
Tags: serializers description customization get res Django openapi swagger TYPE

Full code: https://gitee.com/mom925/django-system

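In my earlier post on configuring Swagger in Django (https://www.cnblogs.com/moon3496694/p/17657283.html), most of the API documentation was still written by hand. I want the documentation to be generated more automatically, so I need to override some functions myself.
For installing the required packages, registering the app and registering the routes, follow that earlier post (https://www.cnblogs.com/moon3496694/p/17657283.html). What follows are improvements built on top of it.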

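To customize the class Swagger (drf_yasg) uses for automatic generation, point DEFAULT_AUTO_SCHEMA_CLASS in the settings at your own class: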

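SWAGGER_SETTINGS = {
    'USE_SESSION_AUTH': False,
    'SECURITY_DEFINITIONS': {
        '身份验证': {  # security scheme name shown in Swagger UI ("authentication")
            'type': 'apiKey',
            'in': 'header',
            'name': 'Authorization'
        }
    },
    # dotted path to the custom auto-schema class defined in utils/swagger.py
    "DEFAULT_AUTO_SCHEMA_CLASS": "utils.swagger.CustomSwaggerAutoSchema",
    # optional custom key read by CustomSwaggerAutoSchema.get_tags (index into operation_keys, default 2)
    # "AUTO_SCHEMA_TYPE": 2,
}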
My swagger.py file:
from django.conf import settings
from drf_yasg import openapi
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.inspectors import SwaggerAutoSchema
from drf_yasg.utils import merge_params, get_object_classes
from rest_framework.parsers import FileUploadParser
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema

# Read SWAGGER_SETTINGS through django.conf instead of importing the project's
# settings module (Wchime.settings) directly
SWAGGER_SETTINGS = settings.SWAGGER_SETTINGS


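def get_consumes(parser_classes):
    # Like drf_yasg.utils.get_consumes, but form media types are kept even when
    # the view also accepts a non-form type such as JSON
    parser_classes = get_object_classes(parser_classes)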
    parser_classes = [pc for pc in parser_classes if not issubclass(pc, FileUploadParser)]
    media_types = [parser.media_type for parser in parser_classes or []]
    return media_types


def get_summary(string):
    if string is not None:
        result = string.strip().replace(" ", "").split("\n")
        return result[0]


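class CustomAutoSchema(AutoSchema):
    def get_description(self, path, method):
        """Return the value of a "tags:" line in the view class docstring, or None."""
        view = self.view
        description = view.get_view_description()
        # _get_description_section() falls back to the whole untitled docstring when the
        # requested section is missing, so only call it if a "tags:" line exists
        if not any(line.startswith('tags:') for line in description.splitlines()):
            return None
        return self._get_description_section(view, 'tags', description)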


class CustomSwaggerAutoSchema(SwaggerAutoSchema):

    def get_tags(self, operation_keys=None):
        tags = super().get_tags(operation_keys)
        if "api" in tags and operation_keys:
            # `operation_keys` looks like ['v1', 'prize_join_log', 'create'];
            # AUTO_SCHEMA_TYPE selects which segment becomes the default tag
            tags[0] = operation_keys[SWAGGER_SETTINGS.get('AUTO_SCHEMA_TYPE', 2)]
        # a "tags:" line in the view class docstring takes precedence
        ca = CustomAutoSchema()
        ca.view = self.view
        tag = ca.get_description(self.path, 'get') or None
        if tag:
            tags[0] = tag
        return tags

    def get_summary_and_description(self):
        description = self.overrides.get('operation_description', None)
        summary = self.overrides.get('operation_summary', None)
        if description is None:
            description = self._sch.get_description(self.path, self.method) or ''
            description = description.strip().replace('\r', '')

            if description and (summary is None):
                # description from docstring... do summary magic
                summary, description = self.split_summary_from_description(description)
        # Unlike drf_yasg's version, never leave the summary empty: fall back to the description
        if summary is None:
            summary = description
        return summary, description

    def get_consumes_form(self):

        return get_consumes(self.get_parser_classes())

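    def add_manual_parameters(self, parameters):
        """
        Overridden so that both JSON and form data can be described: form parameters
        are allowed even when the view also accepts JSON.
        """
        manual_parameters = self.overrides.get('manual_parameters', None) or []

        if manual_parameters:
            # manual parameters replace the auto-generated ones (body and query parameters)
            parameters = []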

        if any(param.in_ == openapi.IN_BODY for param in manual_parameters):  # pragma: no cover
            raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
        if any(param.in_ == openapi.IN_FORM for param in manual_parameters):  # pragma: no cover
            has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)

            if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes_form()):
                raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
                                             "did you forget to set an appropriate parser class on the view?")
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("form parameters can only be applied to "
                                             "(" + ','.join(self.body_methods) + ") HTTP methods")

        return merge_params(parameters, manual_parameters)


# --------------------------------------------------------------------------------------------------------------

from rest_framework import serializers
from drf_yasg import openapi
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.fields import ChoiceField


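def serializer_to_swagger(ser_model, get_req=False):
    """
    Convert a serializer class into openapi.Schema properties.
    With get_req=True, also return the list of required field names.
    """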
    if ser_model is None and get_req is True:
        return {}, []
    elif ser_model is None and get_req is False:
        return {}
    dit = {}
    serializer_field_mapping = {
        ChoiceField: openapi.TYPE_INTEGER,
        PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
        serializers.IntegerField: openapi.TYPE_INTEGER,
        serializers.BooleanField: openapi.TYPE_BOOLEAN,
        serializers.CharField: openapi.TYPE_STRING,
        serializers.DateField: openapi.TYPE_STRING,
        serializers.DateTimeField: openapi.TYPE_STRING,
        serializers.DecimalField: openapi.TYPE_NUMBER,
        serializers.DurationField: openapi.TYPE_STRING,
        serializers.EmailField: openapi.TYPE_STRING,
        serializers.ModelField: openapi.TYPE_OBJECT,
        serializers.FileField: openapi.TYPE_STRING,
        serializers.FloatField: openapi.TYPE_NUMBER,
        serializers.ImageField: openapi.TYPE_STRING,
        serializers.SlugField: openapi.TYPE_STRING,
        serializers.TimeField: openapi.TYPE_STRING,
        serializers.URLField: openapi.TYPE_STRING,
        serializers.UUIDField: openapi.TYPE_STRING,
        serializers.IPAddressField: openapi.TYPE_STRING,
        serializers.FilePathField: openapi.TYPE_STRING,
    }
    fields = ser_model().get_fields()
    if get_req:
        required = []
        for k, v in fields.items():
            # label may be None (no explicit label/verbose_name), which would break the += below
            description = getattr(v, 'label', None) or ''
            if isinstance(v, serializers.SerializerMethodField) or getattr(v, 'source', None):
                continue
            elif isinstance(v, ChoiceField):
                description += str(dict(getattr(v, 'choices', {})))

            if getattr(v, 'required', True) is not False:
                required.append(k)
            typ = serializer_field_mapping.get(type(v), openapi.TYPE_STRING)
            dit[k] = openapi.Schema(description=description, type=typ)
        return dit, required
    else:
        for k, v in fields.items():
            description = getattr(v, 'label', None) or ''  # label may be None
            if isinstance(v, ChoiceField):
                description += str(dict(getattr(v, 'choices', {})))
            elif isinstance(v, serializers.SerializerMethodField):
                continue
            typ = serializer_field_mapping.get(type(v), openapi.TYPE_STRING)
            dit[k] = openapi.Schema(description=description, type=typ)

        return dit


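def serializer_to_req_form_swagger(ser_model, filter_fields):
    """Convert a serializer class into formData openapi.Parameter objects, skipping names in filter_fields."""
    li = list()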
    serializer_field_mapping = {
        ChoiceField: openapi.TYPE_INTEGER,
        PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
        serializers.IntegerField: openapi.TYPE_INTEGER,
        serializers.BooleanField: openapi.TYPE_BOOLEAN,
        serializers.CharField: openapi.TYPE_STRING,
        serializers.DateField: openapi.TYPE_STRING,
        serializers.DateTimeField: openapi.TYPE_STRING,
        serializers.DecimalField: openapi.TYPE_NUMBER,
        serializers.DurationField: openapi.TYPE_STRING,
        serializers.EmailField: openapi.TYPE_STRING,
        serializers.ModelField: openapi.TYPE_OBJECT,
        serializers.FileField: openapi.TYPE_FILE,
        serializers.FloatField: openapi.TYPE_NUMBER,
        serializers.ImageField: openapi.TYPE_FILE,
        serializers.SlugField: openapi.TYPE_STRING,
        serializers.TimeField: openapi.TYPE_STRING,
        serializers.URLField: openapi.TYPE_STRING,
        serializers.UUIDField: openapi.TYPE_STRING,
        serializers.IPAddressField: openapi.TYPE_STRING,
        serializers.FilePathField: openapi.TYPE_STRING,
    }
    fields = ser_model().get_fields()
    for k, v in fields.items():
        if k in filter_fields:
            continue
        description = getattr(v, 'label', None) or ''  # label may be None
        if isinstance(v, serializers.SerializerMethodField) or getattr(v, 'source', None):
            continue
        elif isinstance(v, ChoiceField):
            description += str(dict(getattr(v, 'choices', {})))
        req = getattr(v, 'required', True)
        typ = serializer_field_mapping.get(type(v), openapi.TYPE_STRING)
        li.append(openapi.Parameter(name=k, description=description, type=typ, required=req, in_=openapi.IN_FORM))
    return li


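class ViewSwagger(object):
    """
    Per-HTTP-method arguments for drf_yasg's swagger_auto_schema. Each classmethod
    (get/post/put/delete) returns a dict with manual_parameters, request_body and responses.
    """
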
    get_req_params = []
    get_req_body = None
    get_res_data = None
    get_res_examples = {'json': {}}
    get_res_description = ' '
    get_res_code = 200
    get_tags = None
    get_operation_description = None

    post_req_params = []
    post_req_body = None
    post_res_data = None
    post_res_examples = {'json': {}}
    post_res_description = ' '
    post_res_code = 200
    post_tags = None
    post_operation_description = None

    put_req_params = []
    put_req_body = None
    put_res_data = None
    put_res_examples = {'json': {}}
    put_res_description = ' '
    put_res_code = 200
    put_tags = None
    put_operation_description = None

    delete_req_params = []
    delete_req_body = None
    delete_res_data = None
    delete_res_examples = {'json': {}}
    delete_res_description = ' '
    delete_res_code = 200
    delete_tags = None
    delete_operation_description = None

    @classmethod
    def req_serialize_schema(cls, serializer):
        return serializer_to_swagger(serializer, get_req=True)

    @classmethod
    def res_serializer_schema(cls, serializer):
        return serializer_to_swagger(serializer, get_req=False)

    @classmethod
    def req_serializer_form_schema(cls, serializer, filter_fields=()):
        return serializer_to_req_form_swagger(serializer, filter_fields)

    @classmethod
    def get(cls):
        ret = {
            'manual_parameters': cls.get_req_params,
            'request_body': cls.get_req_body,
            'responses': {
                cls.get_res_code: openapi.Response(description=cls.get_res_description, schema=cls.get_res_data,
                                                   examples=cls.get_res_examples)} if cls.get_res_data else None
        }
        return ret

    @classmethod
    def post(cls):
        ret = {
            'manual_parameters': cls.post_req_params,
            'request_body': cls.post_req_body,
            'responses': {
                cls.post_res_code: openapi.Response(description=cls.post_res_description, schema=cls.post_res_data,
                                                   examples=cls.post_res_examples)} if cls.post_res_data else None
        }
        return ret

    @classmethod
    def put(cls):
        ret = {
            'manual_parameters': cls.put_req_params,
            'request_body': cls.put_req_body,
            'responses': {
                cls.put_res_code: openapi.Response(description=cls.put_res_description, schema=cls.put_res_data,
                                                   examples=cls.put_res_examples)} if cls.put_res_data else None
        }
        return ret

    @classmethod
    def delete(cls):
        ret = {
            'manual_parameters': cls.delete_req_params,
            'request_body': cls.delete_req_body,
            'responses': {
                cls.delete_res_code: openapi.Response(description=cls.delete_res_description, schema=cls.delete_res_data,
                                                   examples=cls.delete_res_examples)} if cls.delete_res_data else None
        }
        return ret


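First, I overrode get_tags: I want a view's tag to be picked up automatically as soon as the view class docstring contains a line like tags: xxxx.
The CustomAutoSchema class above reads the view class docstring and extracts the value of its tags section.
In my example, with such a line in the docstring, all of the view's endpoints ended up under the 测试图片 ("test image") tag.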
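A dependency-free sketch of what that docstring parsing does. It approximates the way DRF splits a view docstring into `name:` sections (the header regex and `inspect.cleandoc` stand in for DRF's own helpers; `split_sections` and `pick_tag` are hypothetical names and the operation keys are made up). One consequence: every line after `tags:` is appended to the tags section until the next `name:` line, so the `tags:` line is best placed last in the docstring.

```python
import inspect
import re

# Assumed to match DRF's section-header pattern: a line starting with "name:".
HEADER_RE = re.compile(r'^[a-zA-Z][0-9A-Za-z_]*:')


def split_sections(docstring):
    """Split a docstring into {header: text}; the '' key holds untitled text."""
    sections = {'': ''}
    current = ''
    for line in inspect.cleandoc(docstring or '').splitlines():
        if HEADER_RE.match(line):
            # "tags: 测试图片" -> header "tags", value "测试图片"
            current, _, value = line.partition(':')
            sections[current] = value.strip()
        else:
            # every other line belongs to the most recent header
            sections[current] += '\n' + line
    return sections


def pick_tag(docstring, operation_keys, index=2):
    """Hypothetical mirror of CustomSwaggerAutoSchema.get_tags."""
    tags = [operation_keys[0]]  # drf_yasg's default tag: the first operation key
    if 'api' in tags and len(operation_keys) > index:
        tags[0] = operation_keys[index]  # index plays the role of AUTO_SCHEMA_TYPE
    tag = split_sections(docstring).get('tags', '').strip()
    if tag:  # an explicit "tags:" line wins
        tags[0] = tag
    return tags


doc = """
    Upload and list images.

    tags: 测试图片
"""
print(pick_tag(doc, ['api', 'v1', 'image', 'list']))   # ['测试图片']
print(pick_tag(None, ['api', 'v1', 'image', 'list']))  # ['image']
```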




Next, I overrode get_summary_and_description. The original method can leave the summary empty (None), so I changed it to fall back to summary = description whenever summary is None.
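Why the summary can be empty: drf_yasg only splits off a summary when the docstring starts with a short paragraph followed by a blank line, so a one-paragraph docstring yields no summary. The sketch below approximates that split (the 120-character limit is assumed to match drf_yasg's heuristic) and then applies the fallback; `split_summary` and `summary_and_description` are hypothetical names.

```python
SUMMARY_MAX_LEN = 120  # assumed to match drf_yasg's summary length limit


def split_summary(description):
    """Approximation of drf_yasg's split_summary_from_description."""
    summary = None
    parts = description.split('\n\n', 1)  # first paragraph vs. the rest
    if len(parts) == 2 and len(parts[0].strip()) < SUMMARY_MAX_LEN:
        summary, description = parts[0].strip(), parts[1].strip()
    return summary, description


def summary_and_description(docstring, summary=None, description=None):
    """Hypothetical stand-in for the overridden get_summary_and_description."""
    if description is None:
        description = (docstring or '').strip().replace('\r', '')
        if description and summary is None:
            summary, description = split_summary(description)
    if summary is None:  # the override: never leave the summary empty
        summary = description
    return summary, description


print(summary_and_description("List images.\n\nReturns a paginated list."))
# ('List images.', 'Returns a paginated list.')
print(summary_and_description("List images."))
# ('List images.', 'List images.'); without the fallback the summary would be None
```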
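Both descriptions can be written either in the view class docstring or in the docstring of the handler method; the generated result is the same.

Note that if both places are filled in, the inner one wins: the method's docstring overrides the view class docstring.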
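That precedence comes from how DRF's description lookup picks its source text: an explicit docstring on the handler method is used if present, otherwise the view class docstring. A minimal sketch of the rule, assuming a plain class in place of a real APIView (`ImageView` and `raw_description` are hypothetical; the real lookup also honours a viewset's `action` and `get:` / `post:` sections inside the class docstring):

```python
import inspect


class ImageView:
    """
    Image endpoints (class-level description).
    """

    def get(self, request):
        """
        List images

        The method docstring overrides the class docstring.
        """

    def post(self, request):
        return None  # no docstring, so the class docstring is used


def raw_description(view, method):
    """Approximate which docstring the description is taken from for an HTTP method."""
    handler = getattr(view, method.lower(), None)
    method_doc = handler.__doc__ if handler is not None else None
    if method_doc:
        return inspect.cleandoc(method_doc)  # method docstring wins
    return inspect.cleandoc(type(view).__doc__ or '')  # fall back to the class


view = ImageView()
print(raw_description(view, 'GET'))
print(raw_description(view, 'POST'))  # Image endpoints (class-level description).
```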



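Finally, I overrode add_manual_parameters. Out of the box, auto-generation can describe only one request format: when a view has several parsers, it defaults to JSON (because rest_framework parses JSON by default).
In rest_framework, both form data and JSON are read through request.data. Creating a record typically submits a form while batch deletion submits JSON, yet both are usually implemented in the same view class.
So I give the view class both parsers: parser_classes = [MultiPartParser, JSONParser]
After the override, when both formats are present, the form format is kept and listed first.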
With the view class set up this way, the generated POST operation takes form data and the DELETE operation takes a JSON body.
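The effect of the override can be seen on plain media-type strings. drf_yasg's default `get_consumes` drops the form types whenever the view also accepts a non-form type, and add_manual_parameters then rejects form parameters because no form type remains; the override keeps the list as declared. A simplified sketch, assuming the media types of MultiPartParser and JSONParser (it ignores the `FileUploadParser` filtering, and all function names are hypothetical):

```python
FORM_TYPES = {'application/x-www-form-urlencoded', 'multipart/form-data'}


def is_form(media_type):
    """Simplified stand-in for rest_framework.request.is_form_media_type."""
    return media_type.split(';', 1)[0].strip().lower() in FORM_TYPES


def default_consumes(media_types):
    """Approximation of drf_yasg's default: prefer non-form types when mixed."""
    non_form = [m for m in media_types if not is_form(m)]
    return non_form or list(media_types)


def override_consumes(media_types):
    """What get_consumes() in swagger.py returns: every type, in parser order."""
    return list(media_types)


def form_params_allowed(consumes):
    """Form parameters pass the add_manual_parameters check only if a form type is present."""
    return any(is_form(m) for m in consumes)


# Media types of MultiPartParser and JSONParser, in parser_classes order
parsers = ['multipart/form-data', 'application/json']
print(default_consumes(parsers), form_params_allowed(default_consumes(parsers)))
# ['application/json'] False
print(override_consumes(parsers), form_params_allowed(override_consumes(parsers)))
# ['multipart/form-data', 'application/json'] True
```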

From: https://www.cnblogs.com/moon3496694/p/17967922
