alessandro trinca tornidor
				
			
		feat: move driver.js to a git submodule; remove unused postgis deps; update some broken python test cases
		dbcbde8
		
		| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Callable, NoReturn | |
| from asgi_correlation_id import CorrelationIdMiddleware | |
| import gradio as gr | |
| from starlette.responses import JSONResponse | |
| import structlog | |
| import uvicorn | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, HTTPException, Request, status | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import FileResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from pydantic import ValidationError | |
| from samgis_core.utilities import create_folders_if_not_exists | |
| from samgis_core.utilities import frontend_builder | |
| from samgis_core.utilities.session_logger import setup_logging | |
| from samgis_web.utilities.constants import GRADIO_EXAMPLES_TEXT_LIST, GRADIO_MARKDOWN, GRADIO_EXAMPLE_BODY_STRING_PROMPT | |
| from samgis_web.utilities.type_hints import StringPromptApiRequestBody | |
| from samgis_lisa.io_package.wrappers_helpers import get_parsed_bbox_points_with_string_prompt | |
| from samgis_lisa.prediction_api import lisa as lisa_module | |
| from samgis_lisa.utilities.constants import LISA_INFERENCE_FN | |
# Load variables from a local .env file (if present) before any os.getenv() call below reads them.
load_dotenv()
# Folder containing this module; when __file__ is undefined (e.g. interactive session) the "./_"
# fallback makes this the absolute current working directory.
project_root_folder = Path(globals().get("__file__", "./_")).absolute().parent
# Working directory for the frontend build and static assets; defaults to the project root.
workdir = Path(os.getenv("WORKDIR", project_root_folder))
# NOTE(review): not referenced in the visible part of this module.
model_folder = Path(project_root_folder / "machine_learning_models")
log_level = os.getenv("LOG_LEVEL", "INFO")
# Logging must be configured before the module logger is created and used.
setup_logging(log_level=log_level)
app_logger = structlog.stdlib.get_logger()
app_logger.info(f"PROJECT_ROOT_FOLDER:{project_root_folder}, WORKDIR:{workdir}.")
# Raw string (default "{}") handed to create_folders_if_not_exists.folders_creation() later in this module.
folders_map = os.getenv("FOLDERS_MAP", "{}")
# NOTE(review): markdown_text, examples_text_list and example_body are not referenced in the visible
# code; the Gradio UI uses the GRADIO_* constants from samgis_web instead.
markdown_text = os.getenv("MARKDOWN_TEXT", "")
# NOTE(review): when EXAMPLES_TEXT_LIST is unset this evaluates to [""] (one empty string), not [].
examples_text_list = os.getenv("EXAMPLES_TEXT_LIST", "").split("\n")
# Raises json.JSONDecodeError at import time if EXAMPLE_BODY is set to invalid JSON.
example_body = json.loads(os.getenv("EXAMPLE_BODY", "{}"))
# NOTE(review): bool() of any non-empty string is True, so "false" or "0" also evaluate to True;
# the Gradio app is mounted unconditionally in the visible code anyway.
mount_gradio_app = bool(os.getenv("MOUNT_GRADIO_APP", ""))
# Output folder of the frontend build, also served as static files.
static_dist_folder = workdir / "static" / "dist"
input_css_path = os.getenv("INPUT_CSS_PATH", "src/input.css")
# URL prefixes under which the Gradio app and the built HTML pages are mounted.
vite_gradio_url = os.getenv("VITE_GRADIO_URL", "/gradio")
vite_index_url = os.getenv("VITE_INDEX_URL", "/")
vite_samgis_url = os.getenv("VITE_SAMGIS_URL", "/samgis")
vite_lisa_url = os.getenv("VITE_LISA_URL", "/lisa")
fastapi_title = "samgis-lisa-on-cuda"
app = FastAPI(title=fastapi_title, version="1.0")
async def request_middleware(request, call_next):
    """Delegate an incoming HTTP request to samgis_web's logging middleware.

    NOTE(review): no ``@app.middleware("http")`` decorator is visible on this function; confirm
    it is registered on ``app``.
    """
    from samgis_web.web.middlewares import logging_middleware

    response = await logging_middleware(request, call_next)
    return response
def get_example_complete(example_text):
    """Return the default string-prompt request body as a JSON string.

    The body is a copy of GRADIO_EXAMPLE_BODY_STRING_PROMPT whose "string_prompt" value is
    replaced by ``example_text``. The shared constant itself is never mutated.
    """
    payload = {**GRADIO_EXAMPLE_BODY_STRING_PROMPT, "string_prompt": example_text}
    return json.dumps(payload)
def get_gradio_interface_geojson(fn_inference: Callable):
    """Build the Gradio Blocks UI around ``fn_inference``.

    The UI has a payload textbox and a Submit button in one column and a GeoJSON output
    textbox in another. Clickable examples built from GRADIO_EXAMPLES_TEXT_LIST fill the
    payload textbox. Clicking Submit calls ``fn_inference(payload)`` and writes the result
    into the output textbox.

    Returns:
        The ``gr.Blocks`` instance (not launched).
    """
    example_payloads = [get_example_complete(example) for example in GRADIO_EXAMPLES_TEXT_LIST]
    with gr.Blocks() as gradio_app:
        gr.Markdown(GRADIO_MARKDOWN)
        with gr.Row():
            with gr.Column():
                payload_box = gr.Textbox(lines=1, placeholder=None, label="Payload input")
                submit_button = gr.Button(value="Submit")
            with gr.Column():
                geojson_box = gr.Textbox(lines=1, placeholder=None, label="Geojson Output")
        gr.Examples(examples=example_payloads, inputs=[payload_box])
        submit_button.click(fn_inference, inputs=[payload_box], outputs=[geojson_box])
    return gradio_app
def handle_exception_response(exception: Exception) -> NoReturn:
    """Log diagnostics about a failed inference, then raise an HTTP 500.

    The diagnostics are the ``ls -l`` listings (stdout and stderr) of ``project_root_folder``
    and ``workdir``, plus the exception message.

    Args:
        exception: the error raised during inference.

    Raises:
        HTTPException: always, status 500 with detail "Internal server error on inference",
            chained to ``exception``.
    """
    import subprocess

    def _list_folder(folder):
        """Run ``ls -l <folder>/`` and return ``(stdout, stderr)``; never raises."""
        try:
            # argv list (no shell): the path comes from the environment (WORKDIR) and
            # must not be interpreted by a shell
            completed = subprocess.run(
                ["ls", "-l", f"{folder}/"], capture_output=True, text=True, check=False
            )
        except OSError as os_err:
            # best effort: a missing `ls` must not mask the HTTP 500 raised below
            return None, str(os_err)
        return completed.stdout, completed.stderr

    root_stdout, root_stderr = _list_folder(project_root_folder)
    app_logger.error(f"project_root folder 'ls -l' command output: {root_stdout}.")
    app_logger.error(f"project_root folder 'ls -l' command stderr: {root_stderr}.")
    workdir_stdout, workdir_stderr = _list_folder(workdir)
    app_logger.error(f"workdir folder 'ls -l' command stdout: {workdir_stdout}.")
    app_logger.error(f"workdir folder 'ls -l' command stderr: {workdir_stderr}.")
    app_logger.error(f"inference error:{exception}.")
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
    ) from exception
async def health() -> JSONResponse:
    """Liveness check: log the installed package versions and answer HTTP 200.

    Returns:
        JSONResponse with body ``{"msg": "still alive..."}``.
    """
    from samgis_web.__version__ import __version__ as version_web
    from samgis_core.__version__ import __version__ as version_core
    from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
    from samgis_lisa.__version__ import __version__ as version_samgis_lisa

    version_messages = (
        f"still alive, version_web:{version_web}, version_core:{version_core}.",
        f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}.",
    )
    for message in version_messages:
        app_logger.info(message)
    payload = {"msg": "still alive..."}
    return JSONResponse(status_code=200, content=payload)
async def health_models() -> JSONResponse:
    """Readiness check: make sure the LISA inference function is loaded.

    Logs the installed package versions. If
    ``models_dict[LISA_INFERENCE_FN]["inference"]`` is None, it calls
    ``lisa.load_model_and_inference_fn`` on CUDA, which may take minutes.

    Returns:
        JSONResponse (200) confirming the inference function is loaded.

    Raises:
        HTTPException: 500 when the model entry is missing or the inference function is still
            None after the load attempt.
    """
    from samgis_lisa.prediction_api import lisa
    from samgis_lisa.utilities.constants import LISA_INFERENCE_FN
    from samgis_web.__version__ import __version__ as version_web
    from samgis_core.__version__ import __version__ as version_core
    from lisa_on_cuda.__version__ import __version__ as version_lisa_on_cuda
    from samgis_lisa.__version__ import __version__ as version_samgis_lisa
    from samgis_lisa.prediction_api.global_models import models_dict

    app_logger.info(f"still alive, version_web:{version_web}, version_core:{version_core}.")
    app_logger.info(f"still alive, version_lisa_on_cuda:{version_lisa_on_cuda}, version_samgis_lisa:{version_samgis_lisa}.")
    app_logger.info(f"try to load inference function for '{LISA_INFERENCE_FN}' model...")
    # The KeyError guard must also cover the first lookup: previously a missing entry raised
    # outside the try and bypassed the logged, explicit 500.
    try:
        inference_fn = models_dict[LISA_INFERENCE_FN]["inference"]
        if inference_fn is None:
            app_logger.info(f"model not found, loading inference function for '{LISA_INFERENCE_FN}' model. This could take some minutes...")
            lisa.load_model_and_inference_fn(LISA_INFERENCE_FN, inference_decorator=None, device_map="auto", device="cuda")
            # re-read: the loader is expected to populate models_dict — TODO confirm
            inference_fn = models_dict[LISA_INFERENCE_FN]["inference"]
    except KeyError as ke:
        app_logger.error(f"model not found, error:{ke}.")
        raise HTTPException(status_code=500, detail="Internal Server Error") from ke
    if inference_fn is None:
        app_logger.error(f"model not found, error:inference function for '{LISA_INFERENCE_FN}' is still None.")
        raise HTTPException(status_code=500, detail="Internal Server Error")
    # getattr fallback: callables such as functools.partial have no __name__
    inference_fn_name = getattr(inference_fn, "__name__", repr(inference_fn))
    app_logger.info(f"inference function for '{LISA_INFERENCE_FN}' model => '{inference_fn_name}' found and loaded...")
    return JSONResponse(status_code=200, content={"msg": f"still alive, inference function for '{LISA_INFERENCE_FN}' model loaded..."})
def infer_lisa_gradio(request_input: StringPromptApiRequestBody) -> str:
    """Run one LISA prediction and return the result serialized as JSON.

    Used as the Gradio Submit callback and wrapped by ``infer_lisa`` for the REST response.

    Args:
        request_input: the string-prompt request payload, parsed by
            ``get_parsed_bbox_points_with_string_prompt``.

    Returns:
        JSON string ``{"duration_run": <seconds>, "output": <lisa_predict output>}``.

    Raises:
        RequestValidationError: when parsing the request raises a pydantic ValidationError.
        HTTPException: 500 for any error during prediction or serialization.
    """
    import time

    app_logger.info("starting lisa inference request...")
    try:
        time_start_run = time.time()
        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
        app_logger.info(f"lisa body_request:{body_request}.")
        try:
            source = body_request["source"]
            source_name = body_request["source_name"]
            app_logger.debug(f"body_request:type(source):{type(source)}, source:{source}.")
            app_logger.debug(f"body_request:type(source_name):{type(source_name)}, source_name:{source_name}.")
            # Log the prediction module actually called below. A bare `lisa` here would resolve
            # to the module-level `lisa()` page handler, not to samgis_lisa.prediction_api.lisa.
            app_logger.debug(f"lisa module:{lisa_module}.")
            output = lisa_module.lisa_predict(
                bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
                source=source, source_name=source_name, inference_function_name_key=LISA_INFERENCE_FN
            )
            duration_run = time.time() - time_start_run
            app_logger.info(f"duration_run:{duration_run}.")
            body = {
                "duration_run": duration_run,
                "output": output
            }
            dumped = json.dumps(body)
            app_logger.info(f"json.dumps(body) type:{type(dumped)}, len:{len(dumped)}.")
            app_logger.debug(f"complete json.dumps(body):{dumped}.")
            return dumped
        except Exception as inference_exception:
            app_logger.error(f"inference_exception:{inference_exception}.")
            app_logger.error(f"inference_exception, request_input:{request_input}.")
            # chain the cause so the original traceback survives in the logs
            raise HTTPException(status_code=500, detail="Internal Server Error") from inference_exception
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        app_logger.error(f"ValidationError, request_input:{request_input}.")
        raise RequestValidationError("Unprocessable Entity") from va1
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """REST wrapper around ``infer_lisa_gradio``.

    ``infer_lisa_gradio`` already logs the serialized body's type, length and content, so
    those lines are not repeated here.

    Returns:
        JSONResponse (200) whose ``"body"`` field holds the JSON string produced by
        ``infer_lisa_gradio``.

    Raises:
        Whatever ``infer_lisa_gradio`` raises (RequestValidationError / HTTPException).
    """
    dumped = infer_lisa_gradio(request_input=request_input)
    return JSONResponse(status_code=200, content={"body": dumped})
def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Forward request-validation errors to samgis_web's shared handler and return its response."""
    from samgis_web.web import exception_handlers as web_exception_handlers

    shared_handler = web_exception_handlers.request_validation_exception_handler
    return shared_handler(request, exc)
def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Forward HTTP exceptions to samgis_web's shared handler and return its response."""
    from samgis_web.web import exception_handlers as web_exception_handlers

    shared_handler = web_exception_handlers.http_exception_handler
    return shared_handler(request, exc)
create_folders_if_not_exists.folders_creation(folders_map)
# Optional folder of inference artifacts; when set, it is served under /vis_output.
write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
if bool(write_tmp_on_disk):
    try:
        # explicit check instead of `assert`, which is stripped under `python -O`
        if not Path(write_tmp_on_disk).is_dir():
            raise NotADirectoryError(f"write_tmp_on_disk:{write_tmp_on_disk} is not an existing directory")
        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
        # Dedicated name: the module-global `templates` is rebound later in this module (frontend
        # setup), and list_files looks its templates object up at request time.
        vis_output_templates = Jinja2Templates(directory=str(project_root_folder / "static"))

        def list_files(request: Request):
            """Render list_files.html with the sorted URLs of the files in write_tmp_on_disk.

            NOTE(review): no route decorator is visible on this function; confirm how it is
            registered.
            """
            # str(request.url) is the public equivalent of the private request.url._url
            files_paths = sorted(f"{request.url}/{file_name}" for file_name in os.listdir(write_tmp_on_disk))
            app_logger.debug(f"files_paths:{files_paths}.")
            return vis_output_templates.TemplateResponse(
                "list_files.html", {"request": request, "files": files_paths}
            )
    except (NotADirectoryError, RuntimeError) as rerr:
        app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
        raise
# Build the frontend into static_dist_folder (rooted at workdir, using input_css_path).
# Keep this before the StaticFiles mounts: StaticFiles checks by default that its directory exists
# when it is instantiated.
frontend_builder.build_frontend(
    project_root_folder=workdir,
    input_css_path=input_css_path,
    output_dist_folder=static_dist_folder
)
app_logger.info("build_frontend ok!")
# NOTE(review): this rebinds the module-global `templates` (directory relative to the process CWD);
# any function reading the global `templates` at request time gets this instance.
templates = Jinja2Templates(directory="templates")
# Serve the built assets under /static; html=True serves index.html for directory requests.
app.mount("/static", StaticFiles(directory=static_dist_folder, html=True), name="static")
# important: the index() function and the app.mount MUST be at the end
# NOTE(review): presumably because a mount at vite_index_url (default "/") matches every path, so
# routes registered after it would be shadowed; confirm.
# samgis.html
# Serve the built assets under vite_samgis_url (default "/samgis").
app.mount(vite_samgis_url, StaticFiles(directory=static_dist_folder, html=True), name="samgis")
async def samgis() -> FileResponse:
    """Return the built samgis.html page.

    NOTE(review): no route decorator is visible on this function; confirm how it is registered.
    """
    return FileResponse(path=str(static_dist_folder / "samgis.html"), media_type="text/html")
# lisa.html
# Serve the built assets under vite_lisa_url (default "/lisa").
app.mount(vite_lisa_url, StaticFiles(directory=static_dist_folder, html=True), name="lisa")
async def lisa() -> FileResponse:
    """Return the built lisa.html page.

    NOTE(review): this module-level name is what a bare `lisa` resolves to in this module; the
    prediction module is imported as `lisa_module`. No route decorator is visible here.
    """
    return FileResponse(path=str(static_dist_folder / "lisa.html"), media_type="text/html")
# index.html (lisa.html copy)
# Serve the built assets under vite_index_url (default "/").
app.mount(vite_index_url, StaticFiles(directory=static_dist_folder, html=True), name="index")
async def index() -> FileResponse:
    """Return the built index.html page.

    NOTE(review): no route decorator is visible on this function; confirm how it is registered.
    """
    return FileResponse(path=str(static_dist_folder / "index.html"), media_type="text/html")
app_logger.info("creating gradio interface...")
# Build the Gradio UI; its Submit button calls infer_lisa_gradio.
gr_interface = get_gradio_interface_geojson(infer_lisa_gradio)
app_logger.info(f"gradio interface created, mounting gradio app on url {vite_gradio_url} within FastAPI...")
# gr.mount_gradio_app returns the FastAPI app with the Gradio app mounted at vite_gradio_url;
# rebinding `app` keeps the later middleware registration and uvicorn.run() pointing at it.
# NOTE(review): this mount is added after the vite_index_url mount; if that is "/", confirm that
# requests to vite_gradio_url are not captured by it.
app = gr.mount_gradio_app(app, gr_interface, path=vite_gradio_url)
app_logger.info("mounted gradio app within fastapi")
# add the CorrelationIdMiddleware AFTER the @app.middleware("http") decorated function to avoid missing request id
# (in Starlette the middleware added last wraps the others, so it runs first on each request)
app.add_middleware(CorrelationIdMiddleware)
if __name__ == '__main__':
    # Serve the FastAPI app (with the mounted Gradio UI) on all interfaces, port 7860.
    try:
        uvicorn.run(host="0.0.0.0", port=7860, app=app)
    except Exception as ex:
        error_message = f"fastapi/gradio application {fastapi_title}, exception:{ex}."
        app_logger.error(error_message)
        # also print, so the failure is visible even if logging is misconfigured
        print(error_message)
        # bare raise re-raises the active exception with its traceback untouched
        raise