完整代码:https://gitee.com/mom925/django-system
之前写的Django配置swagger(https://www.cnblogs.com/moon3496694/p/17657283.html)其实更多还是靠自己手动写代码来编写接口文档。我希望它能更加自动化地生成接口文档,所以需要自己重写一些函数。
安装所需的包,注册app,注册路由参考之前的即可(https://www.cnblogs.com/moon3496694/p/17657283.html),下面是在之前的基础上做的改进
自定义swagger自动生成的类需要在配置里指定自定义的类
# Swagger settings for the Django settings module.
SWAGGER_SETTINGS = {
    # Turn off the session-auth option.
    "USE_SESSION_AUTH": False,
    # A single API-key scheme carried in the `Authorization` request header.
    # The key "身份验证" is the scheme's name (Chinese for "authentication").
    "SECURITY_DEFINITIONS": {
        "身份验证": {
            "type": "apiKey",
            "in": "header",
            "name": "Authorization",
        },
    },
    # Dotted path of the customised auto-schema class to use by default.
    "DEFAULT_AUTO_SCHEMA_CLASS": "utils.swagger.CustomSwaggerAutoSchema",
}
我的swagger.py文件
from django.utils.encoding import smart_str
from drf_yasg import openapi
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.inspectors import SwaggerAutoSchema
from drf_yasg.utils import merge_params, get_object_classes
from rest_framework import serializers
from rest_framework.fields import ChoiceField
from rest_framework.parsers import FileUploadParser
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema
from rest_framework.utils import formatting

from Wchime.settings import SWAGGER_SETTINGS


# Serializer field class -> OpenAPI type used for JSON body / response schemas.
# Lookups use the exact field class; anything not listed falls back to string.
_BODY_FIELD_TYPES = {
    ChoiceField: openapi.TYPE_INTEGER,
    PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
    serializers.IntegerField: openapi.TYPE_INTEGER,
    serializers.BooleanField: openapi.TYPE_BOOLEAN,
    serializers.CharField: openapi.TYPE_STRING,
    serializers.DateField: openapi.TYPE_STRING,
    serializers.DateTimeField: openapi.TYPE_STRING,
    serializers.DecimalField: openapi.TYPE_NUMBER,
    serializers.DurationField: openapi.TYPE_STRING,
    serializers.EmailField: openapi.TYPE_STRING,
    serializers.ModelField: openapi.TYPE_OBJECT,
    serializers.FileField: openapi.TYPE_STRING,
    serializers.FloatField: openapi.TYPE_NUMBER,
    serializers.ImageField: openapi.TYPE_STRING,
    serializers.SlugField: openapi.TYPE_STRING,
    serializers.TimeField: openapi.TYPE_STRING,
    serializers.URLField: openapi.TYPE_STRING,
    serializers.UUIDField: openapi.TYPE_STRING,
    serializers.IPAddressField: openapi.TYPE_STRING,
    serializers.FilePathField: openapi.TYPE_STRING,
}

# Same table for form parameters, except uploads are typed as files.
_FORM_FIELD_TYPES = dict(_BODY_FIELD_TYPES)
_FORM_FIELD_TYPES.update({
    serializers.FileField: openapi.TYPE_FILE,
    serializers.ImageField: openapi.TYPE_FILE,
})


def get_consumes(parser_classes):
    """Return the media types of ``parser_classes``, ignoring FileUploadParser subclasses."""
    parser_classes = get_object_classes(parser_classes)
    return [
        parser.media_type
        for parser in parser_classes
        if not issubclass(parser, FileUploadParser)
    ]


def get_summary(string):
    """Return the first line of ``string`` with all spaces removed, or None if it is None."""
    if string is None:
        return None
    return string.strip().replace(" ", "").split("\n")[0]


class CustomAutoSchema(AutoSchema):
    """AutoSchema variant whose description is the ``tags`` section of the view docstring."""

    def get_description(self, path, method):
        """Look up the ``tags`` section in the view's description text."""
        view = self.view
        return self._get_description_section(view, 'tags', view.get_view_description())


class CustomSwaggerAutoSchema(SwaggerAutoSchema):
    """SwaggerAutoSchema with docstring-driven tags, a summary fallback and form support."""

    def get_tags(self, operation_keys=None):
        """Return the tags for this operation.

        The first tag is replaced, in order of increasing priority, by the
        route segment selected through ``SWAGGER_SETTINGS['AUTO_SCHEMA_TYPE']``
        (default index 2) when the default tags contain ``"api"``, and then by
        the value read from the view docstring when that value is non-empty.
        """
        # Work on a copy so the list handed back by the parent is never mutated.
        tags = list(super().get_tags(operation_keys))

        if "api" in tags and operation_keys:
            # `operation_keys` looks like ['v1', 'prize_join_log', 'create']
            index = SWAGGER_SETTINGS.get('AUTO_SCHEMA_TYPE', 2)
            # Short routes keep the default tag instead of raising IndexError.
            if -len(operation_keys) <= index < len(operation_keys):
                tags[0] = operation_keys[index]

        docstring_schema = CustomAutoSchema()
        docstring_schema.view = self.view
        tag = docstring_schema.get_description(self.path, 'get') or None
        if tag:
            if tags:
                tags[0] = tag
            else:
                tags = [tag]
        return tags

    def get_summary_and_description(self):
        """Return ``(summary, description)``, using the description when no summary exists."""
        description = self.overrides.get('operation_description', None)
        summary = self.overrides.get('operation_summary', None)

        if description is None:
            description = self._sch.get_description(self.path, self.method) or ''
            description = description.strip().replace('\r', '')

            if description and (summary is None):
                # description from docstring... do summary magic
                summary, description = self.split_summary_from_description(description)

        # The split above may leave the summary unset; never return None for it.
        if summary is None:
            summary = description
        return summary, description

    def get_consumes_form(self):
        """Return the media types of the view's parsers, ignoring FileUploadParser subclasses."""
        return get_consumes(self.get_parser_classes())

    def add_manual_parameters(self, parameters):
        """Merge manually declared parameters so a view can accept both JSON and form data.

        When manual parameters are supplied they replace the automatically
        generated ``parameters`` instead of being added to them.

        Raises:
            SwaggerGenerationError: if a manual parameter is a body parameter,
                or if form parameters are declared while none of the view's
                parsers handles form media types, or for an HTTP method that
                is not in ``self.body_methods``.
        """
        manual_parameters = self.overrides.get('manual_parameters', None) or []
        if manual_parameters:
            # Manual parameters take over completely.
            parameters = []

        if any(param.in_ == openapi.IN_BODY for param in manual_parameters):  # pragma: no cover
            raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")

        if any(param.in_ == openapi.IN_FORM for param in manual_parameters):  # pragma: no cover
            # `parameters` was reset above, so this is False here; the check is
            # kept so the guard stays correct if that reset is ever removed.
            has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
            if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes_form()):
                raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
                                             "did you forget to set an appropriate parser class on the view?")
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("form parameters can only be applied to "
                                             "(" + ','.join(self.body_methods) + ") HTTP methods")

        return merge_params(parameters, manual_parameters)


# --------------------------------------------------------------------------------------------------------------


def _describe_field(field):
    """Return the field label, with the choices mapping appended for choice fields."""
    description = getattr(field, 'label', '')
    if isinstance(field, ChoiceField):
        # A label may be None; coerce it so the concatenation cannot raise TypeError.
        description = (description or '') + str(dict(getattr(field, 'choices', {})))
    return description


def _is_excluded_from_request(field):
    """Return True for method fields and for fields declared with a truthy ``source``."""
    if isinstance(field, serializers.SerializerMethodField):
        return True
    return bool(getattr(field, 'source', None))


def serializer_to_swagger(ser_model, get_req=False):
    """Convert a serializer class into OpenAPI schema properties.

    Args:
        ser_model: serializer class (not an instance), or None.
        get_req: truthy to build a request schema, falsy for a response schema.

    Returns:
        Request mode: ``(properties, required)``, a dict of field name to
        ``openapi.Schema`` plus the list of field names whose ``required``
        attribute is not False. Method fields and fields with a truthy
        ``source`` are left out.
        Response mode: the ``properties`` dict only, with method fields left out.
        When ``ser_model`` is None the result is ``({}, [])`` or ``{}``.
    """
    if ser_model is None:
        return ({}, []) if get_req else {}

    properties = {}
    required = []
    for name, field in ser_model().get_fields().items():
        if isinstance(field, serializers.SerializerMethodField):
            continue
        if get_req:
            if getattr(field, 'source', None):
                continue
            if getattr(field, 'required', True) is not False:
                required.append(name)
        swagger_type = _BODY_FIELD_TYPES.get(type(field), openapi.TYPE_STRING)
        properties[name] = openapi.Schema(description=_describe_field(field), type=swagger_type)

    if get_req:
        return properties, required
    return properties


def serializer_to_req_form_swagger(ser_model, filter_fields):
    """Convert a serializer class into a list of form ``openapi.Parameter`` objects.

    Args:
        ser_model: serializer class (not an instance).
        filter_fields: collection of field names to leave out; None means no filter.

    Returns:
        One ``IN_FORM`` parameter per remaining field. Method fields and
        fields with a truthy ``source`` are left out.
    """
    excluded = filter_fields or ()
    parameters = []
    for name, field in ser_model().get_fields().items():
        if name in excluded or _is_excluded_from_request(field):
            continue
        parameters.append(openapi.Parameter(
            name=name,
            description=_describe_field(field),
            type=_FORM_FIELD_TYPES.get(type(field), openapi.TYPE_STRING),
            required=getattr(field, 'required', True),
            in_=openapi.IN_FORM,
        ))
    return parameters


class ViewSwagger(object):
    """Declarative holder of per-HTTP-method Swagger settings.

    Subclasses override the ``<method>_*`` attributes; ``get()``, ``post()``,
    ``put()`` and ``delete()`` each return a dict with the keys
    ``manual_parameters``, ``request_body`` and ``responses``.
    """

    # NOTE: the list/dict defaults below are shared by every subclass that does
    # not override them; assign new values in subclasses rather than mutating.
    # NOTE(review): `<method>_tags` and `<method>_operation_description` are
    # declared here but not read by any method of this class.
    get_req_params = []
    get_req_body = None
    get_res_data = None
    get_res_examples = {'json': {}}
    get_res_description = ' '
    get_res_code = 200
    get_tags = None
    get_operation_description = None

    post_req_params = []
    post_req_body = None
    post_res_data = None
    post_res_examples = {'json': {}}
    post_res_description = ' '
    post_res_code = 200
    post_tags = None
    post_operation_description = None

    put_req_params = []
    put_req_body = None
    put_res_data = None
    put_res_examples = {'json': {}}
    put_res_description = ' '
    put_res_code = 200
    put_tags = None
    put_operation_description = None

    delete_req_params = []
    delete_req_body = None
    delete_res_data = None
    delete_res_examples = {'json': {}}
    delete_res_description = ' '
    delete_res_code = 200
    delete_tags = None
    delete_operation_description = None

    @classmethod
    def req_serialize_schema(cls, serializer):
        """Return ``(properties, required)`` for a request built from ``serializer``."""
        return serializer_to_swagger(serializer, get_req=True)

    @classmethod
    def res_serializer_schema(cls, serializer):
        """Return the response ``properties`` dict built from ``serializer``."""
        return serializer_to_swagger(serializer, get_req=False)

    @classmethod
    def req_serializer_form_schema(cls, serializer, filter_fields=None):
        """Return form parameters built from ``serializer``, minus ``filter_fields``."""
        if filter_fields is None:
            filter_fields = []
        return serializer_to_req_form_swagger(serializer, filter_fields)

    @classmethod
    def _build(cls, method):
        """Assemble the settings dict from the ``<method>_*`` class attributes."""
        res_data = getattr(cls, method + '_res_data')
        responses = None
        if res_data:
            # A response entry is only produced when a response schema is set.
            responses = {
                getattr(cls, method + '_res_code'): openapi.Response(
                    description=getattr(cls, method + '_res_description'),
                    schema=res_data,
                    examples=getattr(cls, method + '_res_examples'),
                )
            }
        return {
            'manual_parameters': getattr(cls, method + '_req_params'),
            'request_body': getattr(cls, method + '_req_body'),
            'responses': responses,
        }

    @classmethod
    def get(cls):
        """Return the settings dict built from the ``get_*`` attributes."""
        return cls._build('get')

    @classmethod
    def post(cls):
        """Return the settings dict built from the ``post_*`` attributes."""
        return cls._build('post')

    @classmethod
    def put(cls):
        """Return the settings dict built from the ``put_*`` attributes."""
        return cls._build('put')

    @classmethod
    def delete(cls):
        """Return the settings dict built from the ``delete_*`` attributes."""
        return cls._build('delete')
首先重写了get_tags方法,我希望只要在视图类下面注释里写上tags:"xxxx"即可自动的读取到。
上面写的CustomAutoSchema类就是读取了视图类的注释,然后获取出里面的tags值
只需要这样写:
然后即可生成:
可以看到,这些接口都归在“测试图片”标签下
重写get_summary_and_description方法,原来的这个方法获取到summary是有可能为空的,所以改成当summary为None时summary=description
如果需要在视图类注释中写这两个描述,则像下面一样:
也可以在方法注释中写,则像下面一样:
得到的结果一样:
注意如果两个地方都写则里面的注释会覆盖外层的,也就是方法中的注释会去覆盖视图类下面的注释
重写add_manual_parameters方法。原来自动生成时只能解析一种数据类型,当传入多种解析类型时会默认使用JSON类型(因为rest_framework默认解析的就是JSON)
因为在rest_framework中我们不管是表单还是json格式都可以request.data获取,像新增时是提交表单,批量删除时提交json格式,但是一般又写在同一个视图类下
所以给视图类指定解析数据类型 parser_classes = [MultiPartParser, JSONParser]
重写以后,如果同时存在两种解析类型,会优先返回表单格式
视图类像下面一样:
得到的post和delete:
得到了post的表单数据和delete的JSON数据
标签:serializers,description,自定义,get,res,Django,openapi,swagger,TYPE From: https://www.cnblogs.com/moon3496694/p/17967922