Spaces:
Paused
Paused
| import http.client | |
| import inspect | |
| import warnings | |
| from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast | |
| from fastapi import routing | |
| from fastapi._compat import ( | |
| GenerateJsonSchema, | |
| JsonSchemaValue, | |
| ModelField, | |
| Undefined, | |
| get_compat_model_name_map, | |
| get_definitions, | |
| get_schema_from_model_field, | |
| lenient_issubclass, | |
| ) | |
| from fastapi.datastructures import DefaultPlaceholder | |
| from fastapi.dependencies.models import Dependant | |
| from fastapi.dependencies.utils import get_flat_dependant, get_flat_params | |
| from fastapi.encoders import jsonable_encoder | |
| from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE | |
| from fastapi.openapi.models import OpenAPI | |
| from fastapi.params import Body, Param | |
| from fastapi.responses import Response | |
| from fastapi.types import ModelNameMap | |
| from fastapi.utils import ( | |
| deep_dict_update, | |
| generate_operation_id_for_path, | |
| is_body_allowed_for_status_code, | |
| ) | |
| from starlette.responses import JSONResponse | |
| from starlette.routing import BaseRoute | |
| from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY | |
| from typing_extensions import Literal | |
# JSON Schema of one entry in a validation-error response: an array location
# ("loc", whose items are strings or integers), a message ("msg") and an error
# type ("type"); all three fields are required.
validation_error_definition = dict(
    title="ValidationError",
    type="object",
    properties=dict(
        loc=dict(
            title="Location",
            type="array",
            items=dict(anyOf=[dict(type="string"), dict(type="integer")]),
        ),
        msg=dict(title="Message", type="string"),
        type=dict(title="Error Type", type="string"),
    ),
    required=["loc", "msg", "type"],
)
# JSON Schema of the whole validation-error response body: an object whose
# "detail" member is an array of ValidationError entries, referenced through
# the component-schema prefix REF_PREFIX.
validation_error_response_definition = dict(
    title="HTTPValidationError",
    type="object",
    properties=dict(
        detail=dict(
            title="Detail",
            type="array",
            items={"$ref": REF_PREFIX + "ValidationError"},
        )
    ),
)
# Fallback descriptions for wildcard response keys ("1XX" .. "5XX") and the
# catch-all key. Callers in this module upper-case the key before looking it
# up, which is why the catch-all entry is spelled "DEFAULT".
status_code_ranges: Dict[str, str] = dict(
    [
        ("1XX", "Information"),
        ("2XX", "Success"),
        ("3XX", "Redirection"),
        ("4XX", "Client Error"),
        ("5XX", "Server Error"),
        ("DEFAULT", "Default Response"),
    ]
)
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract security schemes and requirements from a flattened dependant.

    Returns a pair:

    * a dict mapping each ``scheme_name`` to its scheme model, JSON-encoded
      with ``by_alias=True`` and ``exclude_none=True`` (a later requirement
      with the same scheme name overwrites an earlier one);
    * a list with one ``{scheme_name: scopes}`` entry per security
      requirement, in the order the requirements appear.
    """
    schemes_by_name: Dict[str, Any] = {}
    requirements: List[Dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        encoded_model = jsonable_encoder(
            scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        name = scheme.scheme_name
        schemes_by_name[name] = encoded_model
        requirements.append({name: requirement.scopes})
    return schemes_by_name, requirements
def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
    """Build OpenAPI parameter objects for the given route parameter fields.

    Fields whose ``field_info.include_in_schema`` is falsy are skipped.
    Every emitted object has ``name`` (the field alias), ``in``,
    ``required`` and ``schema``. ``description`` and ``deprecated`` are added
    when set. ``openapi_examples`` is emitted as ``examples`` and takes
    precedence over a single ``example``.
    """
    result: List[Dict[str, Any]] = []
    for field in all_route_params:
        info = cast(Param, field.field_info)
        if not info.include_in_schema:
            continue
        schema = get_schema_from_model_field(
            field=field,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        entry: Dict[str, Any] = {
            "name": field.alias,
            "in": info.in_.value,
            "required": field.required,
            "schema": schema,
        }
        if info.description:
            entry["description"] = info.description
        # Multiple named examples win over the single legacy ``example``.
        if info.openapi_examples:
            entry["examples"] = jsonable_encoder(info.openapi_examples)
        elif info.example != Undefined:
            entry["example"] = jsonable_encoder(info.example)
        if info.deprecated:
            entry["deprecated"] = True
        result.append(entry)
    return result
def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
    """Build the OpenAPI ``requestBody`` object for ``body_field``.

    Returns ``None`` when there is no body field. Otherwise returns a dict
    with ``required`` (present only when the body is required) and
    ``content``, keyed by the body's media type. The media-type entry holds
    the body schema plus ``examples`` (from ``openapi_examples``, preferred)
    or a single ``example`` when one is set.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema = get_schema_from_model_field(
        field=body_field,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    body_info = cast(Body, body_field.field_info)
    media_type = body_info.media_type
    is_required = body_field.required

    media_entry: Dict[str, Any] = {"schema": schema}
    if body_info.openapi_examples:
        media_entry["examples"] = jsonable_encoder(body_info.openapi_examples)
    elif body_info.example != Undefined:
        media_entry["example"] = jsonable_encoder(body_info.example)

    request_body: Dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required
    request_body["content"] = {media_type: media_entry}
    return request_body
def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Deprecated helper returning an operation ID for ``route``/``method``.

    Emits a ``DeprecationWarning`` on every call. Returns the route's explicit
    ``operation_id`` when it is truthy; otherwise derives one from the route
    name, its ``path_format`` and ``method``.
    """
    warnings.warn(
        "fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    return route.operation_id or generate_operation_id_for_path(
        name=route.name, path=route.path_format, method=method
    )
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the OpenAPI summary for ``route``.

    Uses ``route.summary`` when it is truthy; otherwise title-cases the route
    name with underscores turned into spaces (``"read_items"`` ->
    ``"Read Items"``). ``method`` is accepted but not used.
    """
    return route.summary or route.name.replace("_", " ").title()
def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
    """Build the descriptive part of an OpenAPI operation for ``route``.

    The result may contain ``tags``, always contains ``summary`` and
    ``operationId``, and may contain ``description`` and ``deprecated``.
    The operation ID is ``route.operation_id`` or, failing that,
    ``route.unique_id``. It is added to ``operation_ids`` (mutated in place).
    A warning is issued if the ID was already present in that set.
    """
    metadata: Dict[str, Any] = {}
    if route.tags:
        metadata["tags"] = route.tags
    metadata["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        metadata["description"] = route.description

    op_id = route.operation_id or route.unique_id
    if op_id in operation_ids:
        endpoint_name = route.endpoint.__name__
        warning_text = (
            f"Duplicate Operation ID {op_id} for function {endpoint_name}"
        )
        # Plain functions expose their defining module's globals; use them to
        # point the user at the source file when possible.
        source_file = getattr(route.endpoint, "__globals__", {}).get("__file__")
        if source_file:
            warning_text = f"{warning_text} at {source_file}"
        warnings.warn(warning_text, stacklevel=1)
    operation_ids.add(op_id)
    metadata["operationId"] = op_id

    if route.deprecated:
        metadata["deprecated"] = route.deprecated
    return metadata
def _get_default_status_code(response_class: Type[Response]) -> Optional[str]:
    """Return the default ``status_code`` of ``response_class`` as a string.

    The value is read from the ``status_code`` parameter of the class's
    ``__init__`` signature. Returns ``None`` when there is no such parameter or
    its default is not an ``int``.
    """
    # It would probably make more sense for all response classes to have an
    # explicit default status_code, and to extract it from them, instead of
    # doing this inspection tricks, that would probably be in the future
    # TODO: probably make status_code a default class attribute for all
    # responses in Starlette
    init_signature = inspect.signature(response_class.__init__)
    status_code_param = init_signature.parameters.get("status_code")
    if status_code_param is not None and isinstance(status_code_param.default, int):
        return str(status_code_param.default)
    return None


def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: Set[str],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI path item for ``route``.

    Returns ``(path, security_schemes, definitions)``:

    * ``path`` maps each lower-cased HTTP method of the route to its
      operation object;
    * ``security_schemes`` holds the security scheme definitions referenced
      by those operations, including operations of the route's callbacks;
    * ``definitions`` holds the ``ValidationError``/``HTTPValidationError``
      schemas whenever a 422 response referencing them was added (for this
      route or any of its callbacks).

    All three are empty when ``route.include_in_schema`` is false.
    ``operation_ids`` is mutated: every generated operation ID is added to it.

    Raises:
        ValueError: if ``route.status_code`` is None and the response class
            does not declare an ``int`` default for ``status_code``, so no
            response key can be determined.
    """
    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            parameters: List[Dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            parameters.extend(operation_parameters)
            if parameters:
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # Make sure required definitions of the same parameter take precedence
                # over non-required definitions
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    schema_generator=schema_generator,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                        # Callback operations can reference security schemes
                        # and the validation-error schemas; hand them up so
                        # the caller registers them as components instead of
                        # leaving dangling references in the document.
                        security_schemes.update(cb_security_schemes)
                        definitions.update(cb_definitions)
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                default_status_code = _get_default_status_code(current_response_class)
                if default_status_code is None:
                    # Without a status code there is no response key to
                    # document; fail with an actionable message rather than
                    # an UnboundLocalError further down.
                    raise ValueError(
                        "Cannot determine the status code to document for route "
                        f"{route.path!r}: response class "
                        f"{current_response_class.__name__} does not declare an "
                        "int default for 'status_code'; set status_code "
                        "explicitly on the route"
                    )
                status_code = default_status_code
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema = get_schema_from_model_field(
                            field=route.response_field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    process_response = additional_response.copy()
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # Wildcard keys ("4XX", "default") are looked up first so
                    # int() is only attempted on concrete status codes.
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Only document the automatic 422 when the route validates input
            # and the user has not already declared a covering response.
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions
def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> List[ModelField]:
    """Collect the model fields needed to build the schemas for ``routes``.

    Only ``APIRoute`` instances with a truthy ``include_in_schema`` are
    inspected; callback routes are walked recursively. The returned list is
    ordered as: callback fields, request body fields, response fields (the
    main response field and any additional ones), then flattened dependency
    parameters.
    """
    callback_fields: List[ModelField] = []
    body_fields: List[ModelField] = []
    response_fields: List[ModelField] = []
    param_fields: List[ModelField] = []
    for route in routes:
        if not (
            getattr(route, "include_in_schema", None)
            and isinstance(route, routing.APIRoute)
        ):
            continue
        if route.body_field:
            assert isinstance(
                route.body_field, ModelField
            ), "A request body must be a Pydantic Field"
            body_fields.append(route.body_field)
        if route.response_field:
            response_fields.append(route.response_field)
        if route.response_fields:
            response_fields.extend(route.response_fields.values())
        if route.callbacks:
            callback_fields.extend(get_fields_from_routes(route.callbacks))
        param_fields.extend(get_flat_params(route.dependant))
    return [*callback_fields, *body_fields, *response_fields, *param_fields]
def _add_openapi_paths(
    *,
    routes: Sequence[BaseRoute],
    paths: Dict[str, Dict[str, Any]],
    components: Dict[str, Dict[str, Any]],
    definitions: Dict[str, Any],
    operation_ids: Set[str],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool,
) -> None:
    """Render every ``APIRoute`` in ``routes`` and merge the results in place.

    Path items go into ``paths`` keyed by ``route.path_format``. Security
    schemes go into ``components["securitySchemes"]`` (created on first use),
    and extra schemas go into ``definitions``. Routes that are not
    ``APIRoute`` are ignored.
    """
    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        path, security_schemes, path_definitions = get_openapi_path(
            route=route,
            operation_ids=operation_ids,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        if path:
            paths.setdefault(route.path_format, {}).update(path)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)


def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: Optional[str] = None,
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    webhooks: Optional[Sequence[BaseRoute]] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Generate the OpenAPI document for ``routes`` and ``webhooks``.

    Optional ``info`` members, ``servers``, ``tags`` and ``webhooks`` are only
    emitted when provided. Component schemas are emitted sorted by name. The
    assembled dict is validated through the ``OpenAPI`` model and returned as
    JSON-compatible data, encoded with ``by_alias=True`` and
    ``exclude_none=True``.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    if summary:
        info["summary"] = summary
    if description:
        info["description"] = description
    if terms_of_service:
        info["termsOfService"] = terms_of_service
    if contact:
        info["contact"] = contact
    if license_info:
        info["license"] = license_info
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers
    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    webhook_paths: Dict[str, Dict[str, Any]] = {}
    # Shared by routes and webhooks so duplicate operation IDs are detected
    # across both.
    operation_ids: Set[str] = set()
    all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
    model_name_map = get_compat_model_name_map(all_fields)
    schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    # Routes first, then webhooks: the order matters for which duplicate
    # operation ID triggers the warning.
    for source_routes, target_paths in (
        (routes or [], paths),
        (webhooks or [], webhook_paths),
    ):
        _add_openapi_paths(
            routes=source_routes,
            paths=target_paths,
            components=components,
            definitions=definitions,
            operation_ids=operation_ids,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
    if definitions:
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore